In [None]:
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import random
import scipy

from PIL import Image
from os import getcwd

from tqdm import tqdm

%matplotlib inline

# Directory holding the sample images, relative to the kernel's working directory.
input_dir = f"{getcwd()}/img/input_samples/"

In [None]:
def freq_filter(x_freq, y_freq, factor=2.4):
    """Isotropic power-law filter 1 / (|f| ** factor + 1e-8) on a 2-D grid.

    Parameters
    ----------
    x_freq, y_freq : 1-D array-like
        Frequency coordinates along x (columns) and y (rows).
    factor : float
        Exponent applied to the radial frequency |f| = hypot(x, y).

    Returns
    -------
    ndarray of shape (len(y_freq), len(x_freq)) (``np.meshgrid`` 'xy' layout),
    rescaled with ``normalize``.
    """
    grid_x, grid_y = np.meshgrid(x_freq, y_freq)
    # The 1e-8 offset keeps a sample at exactly f == 0 finite.
    denominator = np.hypot(grid_x, grid_y) ** factor + 10 ** -8
    return normalize(1 / denominator)


def freq_filter_2d(x_freq, y_freq, x_aspect=1, y_aspect=1, factor=1):
    """Anisotropic low-pass (1 / (1 + |f|)) ** (factor + 1) on a 2-D grid.

    The radial frequency is measured on axes scaled by ``x_aspect`` and
    ``y_aspect``: |f| = hypot(x / x_aspect, y / y_aspect).

    Returns
    -------
    ndarray of shape (len(y_freq), len(x_freq)), rescaled with ``normalize``.
    """
    grid_x, grid_y = np.meshgrid(x_freq, y_freq)
    scaled_radius = np.hypot(grid_x / x_aspect, grid_y / y_aspect)
    attenuation = 1 / (1 + np.abs(scaled_radius))
    return normalize(attenuation ** (factor + 1))


def freq_filter_1d(x_freq, factor=1):
    """1-D low-pass response (1 / (1 + |f|)) ** factor (not normalised)."""
    attenuation = 1 / (1 + np.abs(x_freq))
    return attenuation ** factor


def freq_filter_1d_a(x_freq, factor=1):
    """1-D power-law response (1 / |f|) ** factor with the centre sample zeroed.

    Samples where ``x_freq == 0`` are evaluated with |f| replaced by 1 to avoid a
    division by zero. The element at index ``len(x_freq) // 2`` is then forced
    to 0 (the zero frequency when ``x_freq`` is in fftshift order, e.g. from
    ``freq_numbers_1d``).
    """
    magnitude = np.where(x_freq == 0, 1, np.abs(x_freq))
    response = (1 / magnitude) ** factor
    response[len(x_freq) // 2] = 0
    return response


def freq_sharp_round_filter(x_freq, y_freq, radius, reverse=False):
    """Binary disc mask on a 2-D frequency grid (ideal high-/low-pass).

    A sample (x, y) lies in the disc when x**2 + y**2 <= radius**2. By default
    the disc is zeroed and everything else is 1 (high-pass). With
    ``reverse=True`` only the disc is 1 (low-pass).

    Parameters
    ----------
    x_freq, y_freq : 1-D array-like
        Frequency coordinates along the two axes.
    radius : float
        Disc radius in the same units as the frequencies.
    reverse : bool
        Keep the disc instead of removing it.

    Returns
    -------
    float ndarray of shape (len(x_freq), len(y_freq)), indexed ``[i_x, j_y]``,
    containing only 0.0 and 1.0. Note this is the transpose of the
    ``np.meshgrid`` default ('xy') layout.
    """
    # 'ij' indexing reproduces the [x, y] layout. The shape comes from the
    # inputs rather than from notebook-level globals.
    grid_x, grid_y = np.meshgrid(np.asarray(x_freq), np.asarray(y_freq), indexing='ij')
    inside = grid_x ** 2 + grid_y ** 2 <= radius ** 2
    mask = inside if reverse else ~inside
    # The mask is already in {0, 1}, so it is returned as-is. Min-max rescaling
    # would not change it, and would divide by zero for an all-zero mask.
    return mask.astype(float)


def freq_sharp_square_filter(x_freq, y_freq, width, reverse=False):
    """Binary diamond mask on a 2-D frequency grid (ideal high-/low-pass).

    A sample (x, y) lies in the region when |x| + |y| <= width. Despite the name,
    this is an L1 ball, i.e. a square rotated by 45 degrees. By default the
    region is zeroed (high-pass). With ``reverse=True`` only the region is 1
    (low-pass).

    Parameters
    ----------
    x_freq, y_freq : 1-D array-like
        Frequency coordinates along the two axes.
    width : float
        L1 radius of the region.
    reverse : bool
        Keep the region instead of removing it.

    Returns
    -------
    float ndarray of shape (len(x_freq), len(y_freq)), indexed ``[i_x, j_y]``,
    containing only 0.0 and 1.0.
    """
    # 'ij' indexing reproduces the [x, y] layout. The shape comes from the
    # inputs rather than from notebook-level globals.
    grid_x, grid_y = np.meshgrid(np.asarray(x_freq), np.asarray(y_freq), indexing='ij')
    inside = np.abs(grid_x) + np.abs(grid_y) <= width
    mask = inside if reverse else ~inside
    # Already in {0, 1}; rescaling would only risk 0/0 for an all-zero mask.
    return mask.astype(float)


def spatial_smooth_filter(x_size, y_size, depth, horiz=True):
    """Smooth 1 -> 0 fade ramp of length ``depth`` for cross-fading images.

    The ramp is 1 - (6 t^5 - 15 t^4 + 10 t^3) for t = linspace(0, 1, depth).

    Returns
    -------
    horiz=True  : shape (y_size, depth); the ramp runs along the columns.
    horiz=False : shape (depth, x_size); the ramp runs down the rows.
    """
    t = np.linspace(0, 1, depth)
    ramp = 1 - (6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3)
    if not horiz:
        # Column vector broadcast across x_size columns.
        return ramp[:, np.newaxis] * np.ones((1, x_size))
    # Same row repeated for each of the y_size rows.
    return np.tile(ramp, (y_size, 1))


def find_ft(img):
    """Centred 2-D DFT: ``fft2`` followed by ``fftshift`` (zero frequency in the middle)."""
    return np.fft.fftshift(np.fft.fft2(img))


def find_ft_1d(img):
    """Centred 1-D DFT: ``fft`` followed by ``fftshift``."""
    return np.fft.fftshift(np.fft.fft(img))


def find_ift(ft):
    """Inverse of ``find_ft``: undo the centring with ``ifftshift``, then ``ifft2``.

    The result is complex; callers take ``.real`` where appropriate.
    """
    return np.fft.ifft2(np.fft.ifftshift(ft))


def find_ift_1d(ft):
    """Inverse of ``find_ft_1d``: ``ifftshift`` then ``ifft`` (complex result)."""
    return np.fft.ifft(np.fft.ifftshift(ft))


def freq_numbers_1d(size):
    """Integer frequency indices in fftshift order.

    Odd ``size``: -(size//2) .. size//2 inclusive.
    Even ``size``: -(size/2) .. size/2 - 1.
    """
    half = size // 2
    return np.arange(-half, size - half)
    

def freq_1d(size, step=1):
    """Frequencies (cycles per unit) in fftshift order for ``size`` samples spaced ``step`` apart."""
    return freq_numbers_1d(size) / step / size


def adjust_freq_1d(img):
    """Return a new 1-D array: ``img`` with ``abs(img[0])`` appended at the end.

    ``np.append`` without ``axis`` flattens its inputs.
    NOTE(review): presumably used to close a periodic spectrum plot; confirm.
    """
    closing_value = abs(img[0])
    return np.append(img, closing_value)


def adjust_img_1d(img):
    """Return a new 1-D array: ``img`` with its first sample repeated at the end.

    ``np.append`` without ``axis`` flattens its inputs.
    """
    wrap_value = img[0]
    return np.append(img, wrap_value)
        
        
def normalize(arr):
    """Rescale by absolute extremes: (arr + |min|) / (|min| + |max|).

    This maps [min, max] onto [0, 1] only when min <= 0 <= max.

    NOTE(review): a second ``def normalize`` follows later in this cell and
    shadows this one, so this definition is never called. Consider deleting
    one of the two.
    """
    lo, hi = abs(np.min(arr)), abs(np.max(arr))
    shifted = arr + lo
    return shifted / (lo + hi)


def normalize(arr):
    """Rescale ``arr`` into (0, 1], mapping its maximum to exactly 1.

    The minimum maps to a small positive value rather than 0. The floor
    offset is 1 % of |min|, so e.g. ``np.log`` of the result stays finite when
    the minimum is non-zero. For non-negative input this is exactly
    ``(arr - 0.99 * min) / (max - 0.99 * min)``. Signed extremes are used, so
    arrays that contain negative values are also mapped into (0, 1]. The
    previous ``abs()``-based form sent e.g. [-1, 1] to [-199, 1].

    Parameters
    ----------
    arr : array-like
        Values to rescale.

    Returns
    -------
    ndarray of float, same shape as ``arr``. An all-zero input yields NaN (0 / 0).
    """
    arr = np.asarray(arr)
    lo = np.min(arr)
    hi = np.max(arr)
    # lo - 0.01 * |lo|; written per sign so non-negative input is bit-identical
    # to the 0.99 * min formula.
    offset = 0.99 * lo if lo >= 0 else 1.01 * lo
    return (arr - offset) / (hi - offset)


def normalize_img(img):
    """Linearly stretch ``img`` to 0..255 and truncate to ``int``.

    The minimum maps to 0 and the maximum to 255 (plain min-max stretch). A
    formula using the absolute extremes (|min|, |max|) only agrees with this
    when min <= 0 <= max. Otherwise all-positive or all-negative images would
    not span the full range.

    Parameters
    ----------
    img : array-like
        Image data of any real dtype. It is converted to float first, so
        integer inputs such as uint8 cannot overflow.

    Returns
    -------
    ndarray of int, same shape as ``img``. A constant image (no contrast) is
    returned as all zeros instead of dividing by zero.
    """
    img = np.asarray(img, dtype=float)
    lo = img.min()
    span = img.max() - lo
    if span == 0:
        return np.zeros(img.shape, dtype=int)
    return (255 * (img - lo) / span).astype(int)


def gen_cloud(x_size, y_size, factor=2.4):
    """Synthesise a cloud-like texture by spectrally shaping Gaussian white noise.

    Noise of shape (y_size, x_size) is Fourier-transformed (centred), multiplied
    by ``freq_filter`` (a 1 / |f| ** factor power law), transformed back, and
    stretched to 0..255 with ``normalize_img``. Uses the global
    ``np.random`` state.

    NOTE(review): the filter is evaluated on
    ``np.linspace(-size / 2, size / 2, size)``, which does not coincide with the
    fftshift frequency grid (``freq_numbers_1d``); for even sizes it contains no
    zero sample. Confirm this is intended.
    """
    noise = np.random.normal(0, 1, (y_size, x_size))
    kernel = freq_filter(
        np.linspace(-x_size / 2, x_size / 2, x_size),
        np.linspace(-y_size / 2, y_size / 2, y_size),
        factor=factor,
    )
    shaped = find_ift(find_ft(noise) * kernel).real
    return normalize_img(shaped)


def show_images(*images, vmin=0, vmax=255, x_fig_size=10, cmap='gray', y_fig_size=10, graphs_per_row=2):
    """Display ``images`` on a grid with at most ``graphs_per_row`` columns.

    Parameters
    ----------
    *images : array-like
        Images passed to ``Axes.imshow``. Nothing is drawn if none are given.
    vmin, vmax : float
        Colour-scale limits passed to ``imshow``.
    x_fig_size, y_fig_size : float
        Figure size in inches.
    cmap : str
        Colormap name passed to ``imshow``.
    graphs_per_row : int
        Maximum number of columns in the grid.
    """
    if not images:
        return
    n_images = len(images)
    # Integer ceiling division (the bare name ``ceil`` is not defined here).
    row_num = -(-n_images // graphs_per_row)
    col_num = -(-n_images // row_num)

    # squeeze=False keeps ``axes`` 2-D even for a single image.
    _, axes = plt.subplots(row_num, col_num, sharey=True, squeeze=False,
                           figsize=(x_fig_size, y_fig_size))
    flat_axes = axes.ravel()
    for ax, img in zip(flat_axes, images):
        ax.imshow(img, cmap=cmap, vmin=vmin, vmax=vmax)
    # Hide grid cells left empty when the images do not fill the last row.
    for ax in flat_axes[n_images:]:
        ax.axis('off')
        
        
def make_img_transition_x(img, depth, is_dx_pos=True):
    """Cross-fade the right (or left) ``depth`` columns of ``img`` into a new cloud.

    A fresh cloud of width ``cols + depth`` is generated with ``gen_cloud``.
    The first ``depth`` columns of the cloud are blended into the right edge
    of ``img``; with ``is_dx_pos=False``, the last ``depth`` columns are blended
    into the left edge. The blend uses the smooth ramp from
    ``spatial_smooth_filter``.

    Returns
    -------
    (blended, continuation)
        ``blended`` is a copy of ``img`` with the edge blended. It keeps
        img's dtype, so float blends are cast on assignment. ``continuation``
        is the remaining ``cols``-wide part of the cloud.
    """
    rows, cols = img.shape
    extra = gen_cloud(cols + depth, rows)
    fade = spatial_smooth_filter(cols, rows, depth)  # (rows, depth), 1 -> 0
    blended = np.copy(img)

    if not is_dx_pos:
        fade = np.fliplr(fade)  # 0 -> 1 moving inwards from the left edge
        blended[:, :depth] = img[:, :depth] * fade + extra[:, -depth:] * (1 - fade)
        return blended, extra[:, :-depth]

    blended[:, -depth:] = img[:, -depth:] * fade + extra[:, :depth] * (1 - fade)
    return blended, extra[:, depth:]


def make_img_transition_y(img, depth, is_dy_pos=True):
    """Cross-fade the bottom (or top) ``depth`` rows of ``img`` into a new cloud.

    A fresh cloud of height ``y_size + depth`` is generated with ``gen_cloud``.
    The first ``depth`` rows of the cloud are blended into the bottom edge of
    ``img``; with ``is_dy_pos=False``, the last ``depth`` rows are blended into
    the top edge. The blend uses the smooth ramp from
    ``spatial_smooth_filter(..., horiz=False)``.

    Returns
    -------
    (new_img, continuation)
        ``new_img`` is a copy of ``img`` with the edge blended. It keeps
        img's dtype, so float blends are cast on assignment. ``continuation``
        is the remaining ``y_size``-tall part of the cloud.
    """
    y_size, x_size = img.shape
    additional_img = gen_cloud(x_size, y_size + depth)
    # (depth, x_size) ramp going 1 -> 0 down the rows.
    transition_kernel = spatial_smooth_filter(x_size, y_size, depth, horiz=False)

    new_img = np.copy(img)
    if is_dy_pos:
        # Rows are the first axis: take the last `depth` rows. Bounding this
        # slice by x_size (the column count) selected the wrong rows, or none,
        # whenever x_size < y_size.
        new_img[-depth:, :] = img[-depth:, :] * transition_kernel + \
            additional_img[:depth, :] * (1 - transition_kernel)
        return new_img, additional_img[depth:, :]

    transition_kernel = np.flipud(transition_kernel)  # 0 -> 1 moving down from the top
    new_img[:depth, :] = img[:depth, :] * transition_kernel + \
        additional_img[-depth:, :] * (1 - transition_kernel)
    return new_img, additional_img[:-depth, :]

    
def freq_2d(x_freq, y_freq):
    """Radial frequency hypot(x, y) on the grid, shape (len(y_freq), len(x_freq))."""
    return np.hypot(*np.meshgrid(x_freq, y_freq))


def lin_regression(x, y):
    """Ordinary least-squares fit y ~ a * x + b; returns (a, b).

    In this notebook ``y`` is the original image and ``x`` the restored one.
    Division by zero occurs if ``x`` is constant.
    """
    mean_x = np.mean(x)
    mean_y = np.mean(y)
    covariance = np.mean(x * y) - mean_x * mean_y
    variance = np.mean(x ** 2) - mean_x ** 2
    slope = covariance / variance
    return slope, mean_y - slope * mean_x


def lin_phase(start, end, size):
    """Odd-symmetric linear ramp: -reversed(p) followed by p, p = linspace(start, end, size // 2).

    The result has 2 * (size // 2) samples, i.e. one fewer than ``size`` when
    ``size`` is odd. NOTE(review): confirm odd sizes are not expected.
    """
    positive = np.linspace(start, end, size // 2)
    return np.concatenate((-positive[::-1], positive))

### Cumulus test

In [None]:
# Load the sample as a float grayscale array and remove its mean (zero DC term).
# The `with` block closes the file, and np.asarray makes the float conversion
# explicit instead of relying on NumPy's reflected subtraction on a PIL image.
with Image.open(input_dir + '2.jpg') as src:
    img = np.asarray(src.convert('L'), dtype=float)
img = img - img.mean()
img_fr = find_ft(img)

# img_fr.shape is (rows, cols): x_size is the row count, y_size the column count.
x_size, y_size = img_fr.shape
xx = freq_numbers_1d(x_size)
yy = freq_numbers_1d(y_size)

magn = normalize(np.abs(img_fr))
phase = np.angle(img_fr)

In [None]:
# Image, its log-magnitude and phase spectra, and a reconstruction from the
# normalised magnitude combined with the original phase.
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
panels = [
    (img, 'Zero-mean image'),
    (np.log(magn), 'log(normalised magnitude)'),
    (phase, 'Phase (rad)'),
    (find_ift(magn * np.exp(1j * phase)).real, 'Reconstruction: normalised magnitude + phase'),
]
for ax, (data, title) in zip(axes.ravel(), panels):
    ax.imshow(data, cmap='gray')
    ax.set_title(title)
plt.show()

In [None]:
# Treat the phase map as an image and look at its own centred spectrum.
phase_fr = find_ft(phase)
phase_magn = np.abs(phase_fr)
phase_angle = np.angle(phase_fr)

fig, (ax_magn, ax_angle) = plt.subplots(1, 2, figsize=(10, 10))
ax_magn.imshow(phase_magn, cmap='gray')
ax_magn.set_title('|FT(phase)|')
ax_angle.imshow(phase_angle, cmap='gray')
ax_angle.set_title('arg FT(phase)')
plt.show()

In [None]:
# Rebuild the phase map from a compressed (normalised, square-rooted) magnitude.
compressed_magn = np.sqrt(normalize(phase_magn))
ph_fr = compressed_magn * np.exp(1j * phase_angle)
ph = find_ift(ph_fr).real

fig, ax = plt.subplots()
ax.imshow(ph, cmap='gray')

In [None]:
# Plotly's Surface places z[i, j] at (x[j], y[i]): x spans columns, y spans rows.
# img_fr.shape is (x_size, y_size) = (rows, cols), so the column frequencies are
# `yy` and the row frequencies are `xx`.
fig = go.Figure()
fig.add_trace(go.Surface(x=yy, y=xx, z=np.log10(magn)))
fig.update_layout(
    scene=dict(
        xaxis=dict(title='column frequency index'),
        yaxis=dict(title='row frequency index'),
        zaxis=dict(title='log10 normalised magnitude', range=[-5, 0]),
    )
)
fig.show()

In [None]:
# Zoom on the phase around the zero-frequency sample (half_width samples each side).
half_width = 10
centre_row, centre_col = x_size // 2, y_size // 2
phase_zoom = phase[centre_row - half_width:centre_row + half_width,
                   centre_col - half_width:centre_col + half_width]

fig = go.Figure()
# Axis lengths follow the crop so they stay consistent with z for small images.
fig.add_trace(go.Surface(x=list(range(phase_zoom.shape[1])),
                         y=list(range(phase_zoom.shape[0])),
                         z=phase_zoom))
fig.show()

In [None]:
# Phase profile from the spectrum centre outwards.
# Alternative, along x (columns), two rows off centre:
#     phase[x_size // 2 + 2, y_size // 2:]
# Shown: along y (rows), through the centre column.
fig, ax = plt.subplots()
ax.plot(phase[x_size // 2:, y_size // 2])
ax.set_xlabel('row offset from centre')
ax.set_ylabel('phase (rad)')
ax.set_title('Phase along y axis (centre column)')
plt.show()