### A Load Images

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np  
import os

def compute_cor(S1, S2):
    """Return the mean of the element-wise product of two 2-D arrays.

    Products are accumulated one element at a time in row-major order, and
    the total is divided by the number of elements in ``S1``.
    """
    n_rows, n_cols = S1.shape[0], S1.shape[1]
    total = sum(
        S1[row, col] * S2[row, col]
        for row in range(n_rows)
        for col in range(n_cols)
    )
    return total / S1.size

# Load the selected test image from ./figures-2 as a grayscale array and display it.
figures_dir  = os.path.join(".", "figures-2")
# file_name = 'hyz'
file_name = "p5"
# file_name = "pepper"
image = Image.open(os.path.join(figures_dir, f"{file_name}.jpg")).convert('L')  # 'L' converts to single-channel grayscale (luminance)
I_luminance = np.array(image)
print("Image size: ", I_luminance.size)
print("Image shape: ", I_luminance.shape)
H, W = I_luminance.shape   # numpy image arrays are (rows, cols) = (height, width)


# Plotting
plt.imshow(I_luminance, cmap='gray')  # show the image
plt.axis('off')     # hide the axes
# plt.colorbar()
# plt.savefig(os.path.join(figures_dir, f"{file_name}-luminance.jpg"),
#             bbox_inches='tight',  # crop tightly to the content
#             pad_inches=0,         # remove the extra white margin
#             )
plt.show()


### B Zero-mean

In [None]:
# Subtract the global mean so the luminance signal is zero-mean before the
# correlation and spectrum analyses below. Note: this rebinds I_luminance to
# the zero-mean image; re-running the cell subtracts the (now ~0) mean again.
mean_val = np.mean(I_luminance)

I_luminance = I_luminance-mean_val

plt.imshow(I_luminance, cmap='gray')  # show the image
plt.axis('off')     # hide the axes
plt.colorbar()
plt.show()

### C Neighboring Pixels

In [None]:
def cal_corr_d(S, d):
    """Estimate the horizontal spatial correlation R^S(d) of a 2-D image.

    R^S(d) is the average of ``S[y, x] * S[y, x + d]`` over every pair of
    pixels in the same row that are ``d`` columns apart and both lie inside
    the image (no wrap-around or padding).

    Parameters
    ----------
    S : array-like, shape (height, width)
        Image; should be zero-mean for the result to be a correlation.
    d : int
        Horizontal pixel offset, ``d >= 0``.

    Returns
    -------
    float
        Mean product over all ``height * (width - d)`` pairs, or ``nan`` when
        ``d >= width`` (no pairs exist).

    Raises
    ------
    ValueError
        If ``d`` is negative.
    """
    # Work in float so integer (e.g. uint8) images cannot overflow in the products.
    S = np.asarray(S, dtype=float)
    # numpy image arrays are (rows, cols) == (height, width).
    height, width = S.shape
    if d < 0:
        raise ValueError(f"d must be non-negative, got {d}")
    if d >= width:
        return float("nan")
    # Pair column x with column x + d in the same row, for all rows at once.
    products = S[:, :width - d] * S[:, d:]
    # Normalise by the number of pairs (not the number of rows) so R^S(d) is a
    # true average and comparable across different d.
    return float(products.mean())


# def cal_corr_d(S,d):
#     width, height = S.shape
#     R_S = 0

#     total_N = 0
#     for y in range(height):
#         for x in range(width):
#             if x + d < width:
#                 R_S += S[y,x] * S[y,x+d]
#                 continue
#             else:
#                 R_S += S[y,x] * S[y,x+d-width]
#             total_N += 1
#     R_S /= total_N
#     return R_S

# R^S(d) for every horizontal offset d from 0 to width - 2.
d_list = np.arange(0,I_luminance.shape[1]-1,1)
R_S_list = np.array([cal_corr_d(I_luminance,d) for d in d_list])

plt.plot(d_list, R_S_list)
# plt.yscale('log')
plt.axhline(y=0, color='gray', linestyle='--')
# plt.axvline(x=114.5, color='r', linestyle='--')
plt.xlabel('$d$ (pixel)')
# The y-axis is linear (the log-scale line above is disabled), so the label
# must not claim "log scale".
plt.ylabel('$R^S(d)$')
plt.title('Spatial Correlation $R^S(d)$ vs. Distance $d$')

# # find out the zero-point for R_S_list
# for i in range(1, len(R_S_list)):
#     if R_S_list[i-1] * R_S_list[i] < 0:
#         print(f"R_S crosses zero between d={d_list[i-1]} and d={d_list[i]}")
#         break

# plt.savefig(os.path.join(figures_dir, "spatial_correlation_vs_distance.pdf"),
#             bbox_inches='tight',  # crop tightly to the content
#             pad_inches=0.1,       # trim extra whitespace
#             )
plt.show()


In [None]:
def cal_corr_d_prime(S, d):
    """Estimate R^S(d) using a mirror-reflected right boundary.

    Every pixel ``S[y, x]`` is paired with the pixel ``d`` columns to its
    right. When ``x + d`` runs past the last column, the partner is taken from
    the image mirrored about its right edge (edge pixel repeated), i.e. column
    ``2 * width - (x + d) - 1``. All ``height * width`` pairs are averaged.

    Parameters
    ----------
    S : array-like, shape (height, width)
        Image; should be zero-mean for the result to be a correlation.
    d : int
        Horizontal pixel offset, ``0 <= d <= width``.

    Returns
    -------
    float
        Mean product over all pixel pairs.

    Raises
    ------
    ValueError
        If ``d`` is outside ``[0, width]`` (the reflection would leave the image).
    """
    # Work in float so integer (e.g. uint8) images cannot overflow in the products.
    S = np.asarray(S, dtype=float)
    # numpy image arrays are (rows, cols) == (height, width).
    height, width = S.shape
    if not 0 <= d <= width:
        raise ValueError(f"d must satisfy 0 <= d <= {width}, got {d}")
    shifted = np.arange(width) + d
    # Reflect out-of-range columns back inside: width -> width-1, width+1 -> width-2, ...
    partner_cols = np.where(shifted < width, shifted, 2 * width - shifted - 1)
    # Average over *all* pairs. The previous loop only counted the reflected
    # pairs (the in-range branch hit `continue`), so e.g. d=0 divided by zero.
    return float(np.mean(S * S[:, partner_cols]))

# Same sweep as above, using the mirror-reflected boundary estimator.
d_list = np.arange(0,I_luminance.shape[1]-1,1)
R_S_list_prime = np.array([cal_corr_d_prime(I_luminance,d) for d in d_list])

plt.plot(d_list, R_S_list_prime)
# plt.yscale('log')
plt.axhline(y=0, color='gray', linestyle='--')
# plt.axvline(x=114.5, color='r', linestyle='--')
plt.xlabel('$d$ (pixel)')
# Linear y-axis (log scale disabled above), so no "(log scale)" in the label.
plt.ylabel('$R^S(d)$')
plt.title('Spatial Correlation $R^S(d)$ vs. Distance $d$')

plt.savefig(os.path.join(figures_dir, "spatial_correlation_vs_distance_prime.pdf"),
            bbox_inches='tight',  # crop tightly to the content
            pad_inches=0.1,       # trim extra whitespace
            )
plt.show()

**As the distance $d$ increases, the *approximate* spatial correlation $R^S(d)$ decreases very quickly.**

It can be seen that, as the distance $d$ between pixels increases, the approximate spatial correlation $R^S(d)$ between them decreases very quickly, roughly exponentially. For some $d$ the correlation is even negative, which might be due to the limited data sample of the image and the specific features in this image. It indicates that pixels far away from each other tend to be less correlated, which is reasonable as the features in natural images are usually locally continuous.


### D-E Fourier waves in any directions

In [None]:
# 2-D FFT of I_luminance (zero-mean after Part B), with the zero-frequency
# term moved to the centre so the spectrum is easier to read.
F_I = np.fft.fft2(I_luminance)
F_I_shift = np.fft.fftshift(F_I)  # shift the zero-frequency component to the center
magnitude_matrix = np.abs(F_I_shift)
magnitude_log_matrix = np.log1p(magnitude_matrix)  # log(1 + |S(k)|) compresses the dynamic range for display
shape_F_I = F_I_shift.shape
print("Magnitude matrix shape: ", magnitude_log_matrix.shape)

In [None]:
# Sanity check: undo the centring (ifftshift) and invert the FFT; the result
# should reproduce the zero-mean image. .real drops the imaginary part, which
# is expected to be only numerical round-off for a real input image.
I_back = np.fft.ifft2(np.fft.ifftshift(F_I_shift)).real
plt.imshow(I_back, cmap='gray')  # show the image
# plt.colorbar()
plt.axis('off')     # hide the axes
plt.title('Reconstructed Image from Inverse FFT')
# plt.savefig(os.path.join(figures_dir, "reconstructed_image_from_inverse_fft.pdf"),
#             bbox_inches='tight',  # crop tightly to the content
#             pad_inches=0.1,       # trim extra whitespace
#             )

In [None]:
# plot the magnitude spectrum of the Fourier transform

def plot_magnitude_spectrum(magnitude_log_matrix, save_fig=False, filename="frequency_spectrum.pdf"):
    """Display a centred (fftshift-ed) log-magnitude spectrum as a heat map.

    Parameters
    ----------
    magnitude_log_matrix : 2-D array
        Spectrum to show, e.g. ``np.log1p(np.abs(fftshift(fft2(img))))``.
    save_fig : bool
        If True, also save the figure to ``figures_dir/filename``.
    filename : str
        Output file name used when ``save_fig`` is True.
    """
    shape = magnitude_log_matrix.shape
    plt.figure(figsize=(8,6))
    # Axes span +/- the number of columns (x) and rows (y) of the matrix.
    plt.imshow(magnitude_log_matrix, extent=[-shape[1], shape[1], -shape[0], shape[0]])  # show the image
    plt.xlabel("$k_x$", fontsize=14)
    plt.ylabel("$k_y$", fontsize=14)
    plt.title("$\\log|\\mathcal{S}(k)|$", fontsize=14)
    plt.colorbar()
    if save_fig:
        plt.savefig(os.path.join(figures_dir, filename),
            bbox_inches='tight',  # crop tightly to the content
            pad_inches=0.1,       # trim extra whitespace
            )
    plt.show()

plot_magnitude_spectrum(magnitude_log_matrix, save_fig=False)

In [None]:
# Compute the magnitude matrix K out of k_x and k_y
# the frquency components used along x and y directions

def get_ff_coordinate_matrix(shape):
    '''Return the radial spatial frequency |k| for every entry of a centred 2-D spectrum.

    Parameters
    ----------
    shape : tuple of int
        ``(height, width)`` of the spectrum, i.e. numpy's ``(rows, cols)``.

    Returns
    -------
    K : ndarray with shape ``shape``
        ``K[i, j] = sqrt(kx[j]**2 + ky[i]**2)``, where ``kx``/``ky`` are the
        fftshift-ed ``np.fft.fftfreq`` values (cycles per pixel) along the
        columns/rows, so ``K`` lines up element-wise with ``fftshift(fft2(img))``.
    '''
    # numpy shapes are (rows, cols) == (H, W); unpacking as (W, H) gave a
    # transposed K for non-square inputs.
    H, W = shape
    kx = np.fft.fftshift(np.fft.fftfreq(W))  # varies along columns (axis 1)
    ky = np.fft.fftshift(np.fft.fftfreq(H))  # varies along rows (axis 0)
    # Broadcast (1, W) against (H, 1) instead of filling the grid in Python.
    K = np.sqrt(kx[np.newaxis, :] ** 2 + ky[:, np.newaxis] ** 2)
    return K


def _radial_power_profile(k_list, power_list):
    """Return (unique_k, avg_powers): mean of power_list over entries sharing each distinct k_list value."""
    unique_k, k_group = np.unique(k_list, return_inverse=True)
    k_group = k_group.ravel()  # some NumPy versions return the inverse with the input's shape
    # One bincount pass replaces a per-k np.where scan over all pixels (O(#unique_k * #pixels)).
    power_sums = np.bincount(k_group, weights=power_list)
    counts = np.bincount(k_group)
    return unique_k, power_sums / counts


def plot_power_vs_k(magnitude_matrix, k_scale=1e7, save_fig=False, filename=None, xvline=None):
    """Scatter |S(k)|^2 against |k| on log-log axes with its radial average and a ~1/k^2 reference.

    Parameters
    ----------
    magnitude_matrix : 2-D array
        Centred (fftshift-ed) magnitude spectrum |S(k)|.
    k_scale : float
        Multiplier applied to 1/k^2 so the reference line sits near the data.
    save_fig : bool
        Save the figure to ``figures_dir/filename`` when True.
    filename : str or None
        Output file name; defaults to ``"power_vs_k.pdf"``.
    xvline : float or None
        If given, draw a dashed vertical line at this |k| labelled k_low.
    """
    K = get_ff_coordinate_matrix(magnitude_matrix.shape)

    # Flatten so each spectrum entry becomes one (|k|, power) sample.
    k_list = K.flatten()
    power_list = (magnitude_matrix ** 2).flatten()

    # Radially averaged power for each distinct |k|.
    unique_k, avg_powers = _radial_power_profile(k_list, power_list)

    # ~1/k^2 reference curve; the DC term k = 0 is excluded to avoid 1/0.
    nonzero_k = unique_k[unique_k != 0]
    k_square_reverse = (1 / nonzero_k ** 2) * k_scale  # scaled for better visualization

    # plot
    plt.scatter(k_list, power_list, s=1, alpha=0.5, color='gray', rasterized=True)
    plt.xlabel("$|k|$", fontsize=14)
    plt.ylabel("$|\\mathcal{S}(k)|^2$", fontsize=14)
    plt.yscale('log')
    plt.xscale('log')
    plt.title("$|\\mathcal{S}(k)|^2$ and $\\sim 1/k^2$ (red) vs. $|k|$", fontsize=14)
    if xvline is not None:
        plt.axvline(x=xvline, color='blue', linestyle='--')
        plt.text(xvline*1.1, power_list.max()/10, f'$k_{{low}}\\approx{xvline}$', color='blue')

    # radially averaged power
    plt.plot(unique_k, avg_powers, color='green', label='Average power', zorder=2, alpha=0.6, linestyle='--')
    # 1/k^2 reference
    plt.plot(nonzero_k, k_square_reverse, marker='None', color='red', linestyle='--', zorder=3)
    plt.text(unique_k[len(unique_k)//1000], k_square_reverse.max()/1000, '$\\sim 1/k^2$', color='red')

    if save_fig:
        if filename is None:
            filename = "power_vs_k.pdf"
        plt.savefig(os.path.join(figures_dir, filename),
            bbox_inches='tight',  # crop tightly to the content
            pad_inches=0.1,       # trim extra whitespace
            )

    plt.show()

In [None]:
# Power spectrum |S(k)|^2 of the clean zero-mean image against |k|, with a ~1/k^2 reference.
plot_power_vs_k(magnitude_matrix, k_scale=1e7, save_fig=False)

### F Adding noise

In [None]:
# Additive integer noise drawn uniformly from [-N_max, N_max], inclusive on
# both ends so the noise is symmetric about zero (randint's upper bound is
# exclusive, which made the noise slightly negatively biased).
N_max = 100
NOISE_SEED = 0  # fixed seed so the noisy image and everything derived from it is reproducible
noise_rng = np.random.default_rng(NOISE_SEED)
noise_matrix = noise_rng.integers(-N_max, N_max, size=I_luminance.shape, endpoint=True)
I_noisy = I_luminance + noise_matrix
plt.imshow(I_noisy, cmap='gray')  # show the image
plt.axis('off')     # hide the axes
# plt.colorbar()
plt.title(f'Noisy image with random noise ($N_{{max}}$={N_max})')

# plt.savefig(os.path.join(figures_dir, f"Noisy_image.pdf"),
#             bbox_inches='tight',  # crop tightly to the content
#             pad_inches=0.1,       # trim extra whitespace
#             )
plt.show()

In [None]:
# Same centred-spectrum pipeline as for the clean image, applied to the noisy image.
F_I_noisy = np.fft.fft2(I_noisy)
F_I_shift_noisy = np.fft.fftshift(F_I_noisy)  # shift the zero-frequency component to the center
noisy_magnitude_matrix = np.abs(F_I_shift_noisy)
noisy_magnitude_log_matrix = np.log1p(noisy_magnitude_matrix)  # log(1 + |S(k)|) for display
shape_F_I = F_I_shift_noisy.shape  # note: rebinds shape_F_I from the clean-image cell (same shape)
print("Magnitude matrix shape: ", noisy_magnitude_log_matrix.shape)

In [None]:
# Log-magnitude spectrum of the noisy image (uncomment the arguments to save the figure).
plot_magnitude_spectrum(noisy_magnitude_log_matrix, 
                        # save_fig=True, filename="frequency_spectrum_noisy.pdf"
                        )

In [None]:
# Power vs |k| for the noisy image; the dashed line marks |k| = 0.15, the
# k_low later used as the "optimal k_low where S/N=1" in the gain-control cells.
plot_power_vs_k(noisy_magnitude_matrix, k_scale=1e7, 
                save_fig=False, filename="power_vs_k_noisy.pdf", 
                xvline=1.5e-1)

**With random noise, the frequency spectrum doesn't show a clear 'x' pattern along the zero-frequency axes, and the power at high frequency $k$ no longer follows the expected $1/k^2$ decay.**

### G Gain control

In [None]:
# # another noisy image
# N_max = 50
# noise_matrix = np.random.randint(-N_max, N_max, I_luminance.shape)
# I_noisy = I_luminance + noise_matrix
# F_I_noisy = np.fft.fft2(I_noisy)
# F_I_shift_noisy = np.fft.fftshift(F_I_noisy)  # shift the zero-frequency component to the center

def get_gain(k, k_o=1e-8, k_low=0.2):
    """Gain g(k) = (|k| + k_o) * exp(-|k|^2 / k_low^2).

    The first factor grows linearly with |k| (boosting higher frequencies);
    the Gaussian factor rolls frequencies well above ``k_low`` off to zero.
    Built from numpy ufuncs, so it works element-wise on arrays as well as scalars.
    """
    abs_k = np.abs(k)
    linear_boost = abs_k + k_o
    gaussian_rolloff = np.exp(-(abs_k ** 2 / k_low ** 2))
    return linear_boost * gaussian_rolloff

def get_gain_matrix(shape, k_o=1e-8, k_low=0.2):
    """Evaluate get_gain element-wise on the |k| grid from get_ff_coordinate_matrix(shape).

    Parameters
    ----------
    shape : tuple of int
        Shape of the centred spectrum the gain will multiply.
    k_o, k_low : float
        Forwarded to get_gain.

    Returns
    -------
    ndarray
        Gain values with the same shape as the |k| grid.
    """
    K_noise = get_ff_coordinate_matrix(shape)
    # get_gain uses only numpy ufuncs, so it already broadcasts over arrays;
    # np.vectorize just added a slow Python-level loop over every pixel.
    gain_matrix = get_gain(K_noise, k_o=k_o, k_low=k_low)
    return gain_matrix

def output_with_gain_control(F_I_shift_noisy, k_o=1e-8, k_low=0.2,
                             plot_fig=True, title="Denoised Image after Frequency Filtering",
                             save_fig=False, filename=None):
    """Filter a centred spectrum with the gain g(k) and transform it back to image space.

    Parameters
    ----------
    F_I_shift_noisy : 2-D complex array
        fftshift-ed 2-D FFT of the (noisy) image.
    k_o, k_low : float
        Forwarded to get_gain_matrix.
    plot_fig : bool
        Display the filtered image.
    title : str
        Title for the displayed image.
    save_fig : bool
        Save the figure (only when ``plot_fig`` is True).
    filename : str or None
        Output file name; defaults to ``"output_image_gain_control.pdf"``.

    Returns
    -------
    ndarray
        The filtered image (real part of the inverse FFT), so the result is
        usable even with ``plot_fig=False``.
    """
    gain_matrix = get_gain_matrix(F_I_shift_noisy.shape, k_o=k_o, k_low=k_low)

    Output_F_I_shift = F_I_shift_noisy * gain_matrix  # element-wise multiplication

    # Undo the centring before inverting; g depends only on |k|, so the
    # imaginary part is expected to be numerical round-off and is dropped.
    I_denoised = np.fft.ifft2(np.fft.ifftshift(Output_F_I_shift)).real
    if plot_fig:
        plt.imshow(I_denoised, cmap='gray')  # show the image
        plt.axis('off')     # hide the axes
        plt.title(title)
        if save_fig:
            if filename is None:
                filename = "output_image_gain_control.pdf"
            plt.savefig(os.path.join(figures_dir, filename),
                bbox_inches='tight',  # crop tightly to the content
                pad_inches=0.1,       # trim extra whitespace
                )
        plt.show()
    return I_denoised

# Filtered output at the chosen cutoff k_low = 0.15.
output_with_gain_control(F_I_shift_noisy, k_o=1e-8, k_low=0.15,
                         title="$\\mathcal{O}$ after optimal $k_{low}$",
                        #  save_fig=True,
                        #  filename="output_image_optimal_gain_control.pdf"
                         )

# For comparison: a much larger cutoff (k_low = 1, more high frequencies pass)
# and a much smaller one (k_low = 0.03, stronger smoothing).
_ = output_with_gain_control(F_I_shift_noisy, k_o=1e-8, k_low=1,
                         title="$\\mathcal{O}$ after higher-pass",
                         save_fig=False,
                         filename="output_image_higher-pass_gain_control.pdf")
_ = output_with_gain_control(F_I_shift_noisy, k_o=1e-8, k_low=3e-2,
                         title="$\\mathcal{O}$ after lower-pass",
                         save_fig=False,
                         filename="output_image_lower-pass_gain_control.pdf")

### G-prime Gain vs. k

In [None]:

k_low = 0.15    # optimal k_low where S/N=1

gain_matrix = get_gain_matrix(F_I_shift_noisy.shape, k_o=1e-8, k_low=k_low)
K_noise = get_ff_coordinate_matrix(F_I_shift_noisy.shape)

# compute the corresponding flattened lists
k_list = K_noise.flatten()
gain_list = gain_matrix.flatten()

# Average gain for each distinct |k|. A single bincount pass replaces the
# per-k np.where scan over every pixel (O(#unique_k * #pixels)).
unique_k, k_group = np.unique(k_list, return_inverse=True)
k_group = k_group.ravel()
avg_gains = list(np.bincount(k_group, weights=gain_list) / np.bincount(k_group))

# plot
plt.scatter(k_list, gain_list, s=1, alpha=0.5, color='gray', rasterized=True)
plt.yscale('log')
plt.xscale('log')
plt.ylabel("$g(k)$", fontsize=14)
plt.xlabel("Frequency $|k|$", fontsize=14)
# plt.title("The optimal gain $g_k$ versus $k$")

plt.axvline(x=k_low, color='blue', linestyle='--', label='$k_{low}$')
plt.text(k_low*1.1, gain_list.max()/50, f'$k_{{low}}={k_low}$', color='blue')

# plt.savefig(os.path.join(figures_dir, "opt_gain_vs_k.pdf"),
#             bbox_inches='tight',  # crop tightly to the content
#             pad_inches=0.1,       # trim extra whitespace
#             )
plt.show()

### H Fourier transform the gain control function

In [None]:
# Heat map of the gain g(k) over the centred frequency grid.
# (Quick alternative: plt.imshow(np.log(gain_matrix), cmap='gray'); plt.colorbar())

gain_matrix = get_gain_matrix(F_I_shift_noisy.shape, k_o=1e-8, k_low=0.2)

from matplotlib.colors import LogNorm
half_width, half_height = gain_matrix.shape[1], gain_matrix.shape[0]
heatmap_extent = [-half_width, half_width, -half_height, half_height]
# For a logarithmic colour scale, replace `vmax=...` below with
# norm=LogNorm(vmin=gain_matrix[gain_matrix > 0].min(), vmax=gain_matrix.max()).
plt.imshow(gain_matrix, extent=heatmap_extent, vmax=gain_matrix.max(), cmap='viridis')
plt.title("$g(k)$ ", fontsize=14)
plt.xlabel("$k_x$", fontsize=14)
plt.ylabel("$k_y$", fontsize=14)
plt.colorbar()
# plt.savefig(os.path.join(figures_dir, "gain_matrix_heatmap.pdf"),
#                 bbox_inches='tight',  # crop tightly to the content
#                 pad_inches=0.1)       # trim extra whitespace
plt.show()

In [None]:
def plot_ifft_gain_matrix(gain_matrix, k_low, title=None, save_fig=False, filename="receptive_field_gain_matrix.pdf"):
    """Plot the spatial receptive field (convolution kernel) corresponding to a gain g(k).

    The centred gain is un-shifted and inverse-FFT'd; the kernel is then
    fftshift-ed so its origin sits in the middle of the plot.

    Parameters
    ----------
    gain_matrix : 2-D array
        Centred (fftshift-ed) gain, e.g. from get_gain_matrix.
    k_low : float
        Cutoff used to build gain_matrix; shown in the default title.
    title : str or None
        Plot title; defaults to ``"Spatial RF ($k_{low}$=<k_low>)"``.
    save_fig : bool
        Save the figure to ``figures_dir/filename`` when True.
    filename : str or None
        Output file name; None means ``"ifft_gain_matrix_k_low_<k_low>.pdf"``.
    """
    kernel = np.fft.ifft2(np.fft.ifftshift(gain_matrix)).real
    kernel = np.fft.fftshift(kernel)
    # Symmetric colour limits so white in the diverging 'bwr' map means 0 and
    # excitatory (red) / inhibitory (blue) lobes are shown faithfully.
    vlim = np.abs(kernel).max()
    plt.imshow(kernel, cmap='bwr', vmin=-vlim, vmax=vlim)
    if title is None:
        title = f"Spatial RF ($k_{{low}}$={k_low})"
    plt.title(title)
    plt.colorbar()
    plt.axis('off')
    if save_fig:
        if filename is None:
            filename = f"ifft_gain_matrix_k_low_{k_low}.pdf"
        plt.savefig(os.path.join(figures_dir, filename),
                bbox_inches='tight',  # crop tightly to the content
                pad_inches=0.1)       # trim extra whitespace
    plt.show()


# Receptive field at the chosen cutoff. Build the gain with the same k_low
# that the plot title and output file name report.
rf_k_low = 0.15
gain_matrix = get_gain_matrix(F_I_shift_noisy.shape, k_o=1e-8, k_low=rf_k_low)
plot_ifft_gain_matrix(gain_matrix, rf_k_low, save_fig=False, filename=f"receptive_field_gain_matrix_k_low_{rf_k_low}.pdf")

Different $k_{low}$

In [None]:
# Receptive fields for progressively smaller cutoffs k_low.
# Note: the loop rebinds the notebook-level k_low (last value: 0.005).
k_low_list = [0.02, 0.01, 0.005]
for k_low in k_low_list:
    gain_matrix = get_gain_matrix(F_I_shift_noisy.shape, k_o=1e-8, k_low=k_low)
    plot_ifft_gain_matrix(gain_matrix, k_low,save_fig=False, filename=f"receptive_field_gain_matrix_k_low_{k_low}.pdf")