In [None]:
import os
import sys 
import typing as t
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt

# add src to path
sys.path.append('../..')
from utils.db_helper import get_image_data

In [None]:
# True: analyse one luminance channel; False: analyse the R, G and B channels separately.
GRAYSCALE = True
# Root directory for the saved figures, three levels above the working directory.
base_save_path = os.path.join(*([".."] * 3), "exp", "DFT")

In [None]:
# Dataset locations. The defaults are the original machine-specific paths; to run
# elsewhere, set the environment variable DFT_DATABASE_ROOT (the directory holding
# the StyleGanv1, StyleGanv2 and FFHQ folders) before starting the kernel.
DATABASE_ROOT = os.environ.get("DFT_DATABASE_ROOT", "C:/database")
SOURCE_DIR_V1 = f"{DATABASE_ROOT}/StyleGanv1"
SOURCE_DIR_V2 = f"{DATABASE_ROOT}/StyleGanv2"
SOURCE_DIR_FFHQ = f"{DATABASE_ROOT}/FFHQ"


In [None]:
def dft_from_single_channel(np_channel: np.ndarray, low_pct: float = 5, high_pct: float = 95,
                            eps: float = 1e-3) -> np.ndarray:
    """Return the normalized log-magnitude 2-D DFT spectrum of one image channel.

    The spectrum is centred with ``fftshift`` so the zero-frequency term sits in
    the middle, converted to ``log(|F| + eps)``, then linearly scaled so that the
    ``low_pct``-th percentile maps to 0 and the ``high_pct``-th percentile maps
    to 1, and finally clipped to ``[0, 1]``. Percentiles are used instead of the
    true min/max so a few extreme coefficients do not compress the rest of the range.

    Args:
        np_channel: 2-D array holding a single image channel.
        low_pct: Percentile mapped to 0 (default 5).
        high_pct: Percentile mapped to 1 (default 95).
        eps: Offset added before the log to avoid ``log(0)`` (default 1e-3).

    Returns:
        float64 array with the same shape as ``np_channel`` and values in ``[0, 1]``.
    """
    spectrum = np.fft.fftshift(np.fft.fft2(np_channel))
    log_mag = np.log(np.abs(spectrum) + eps)
    fft_min, fft_max = np.percentile(log_mag, [low_pct, high_pct])
    span = fft_max - fft_min
    if span <= 0:
        # Degenerate spectrum (e.g. an all-zero image): the two percentiles coincide
        # and the division below would produce NaN, which would then poison any
        # average taken over many spectra. Fall back to a 0/1 step at the percentile
        # value, which is what clipping yields for every entry that was not NaN.
        return (log_mag > fft_min).astype(np.float64)
    return np.clip((log_mag - fft_min) / span, 0.0, 1.0)

In [None]:
def dft_from_image(np_img: np.ndarray)-> np.ndarray:
    """Compute the normalized log-magnitude DFT of a 2-D or 3-D image.

    A 2-D input is transformed directly. A 3-D input is transformed one channel
    at a time along its last axis and the results are stacked back in place.
    Any other rank raises ``ValueError``.
    """
    rank = len(np_img.shape)
    if rank == 2:
        return dft_from_single_channel(np_img)
    if rank != 3:
        raise ValueError('Bad shape of the image')
    spectra = np.empty(np_img.shape)
    for channel_idx in range(np_img.shape[-1]):
        spectra[..., channel_idx] = dft_from_single_channel(np_img[..., channel_idx])
    return spectra


In [None]:
def dft_from_dataset(src_path: str, grayscale=False)-> np.ndarray:
    """Average the normalized DFT spectra of every image in a dataset.

    Images are obtained from ``get_image_data(src_path, grayscale=grayscale)``;
    each one is converted with ``dft_from_image`` and the element-wise mean of
    all spectra is returned.

    Args:
        src_path: Dataset location, passed unchanged to ``get_image_data``.
        grayscale: Forwarded to ``get_image_data``.

    Returns:
        float64 array shaped like a single image's spectrum, holding the mean spectrum.

    Raises:
        ValueError: If the dataset yields no images, if an image's spectrum shape
            differs from the first one, or if the running sum approaches the
            float64 maximum.
    """
    final_type = np.float64
    print("Loading dataset...")
    # iter() accepts both a generator and a plain sequence from get_image_data.
    dataset_iter = iter(get_image_data(src_path, grayscale=grayscale))
    print("Processing database...")
    first_image = next(dataset_iter, None)
    if first_image is None:
        raise ValueError(f"No images found in {src_path!r}")
    # The first image seeds the accumulator and counts toward the mean, so no
    # image is skipped and a single-image dataset is valid.
    np_dft_dataset = dft_from_image(first_image).astype(final_type)
    number_of_images = 1
    for np_image in dataset_iter:
        np_tmp_dft = dft_from_image(np_image)
        if np_tmp_dft.shape != np_dft_dataset.shape:
            # Refuse to add it, because numpy broadcasting could otherwise merge
            # e.g. a 2-D spectrum into a 3-D sum without any error.
            raise ValueError(
                f"Image {number_of_images + 1} has spectrum shape {np_tmp_dft.shape}, "
                f"expected {np_dft_dataset.shape}"
            )
        np_dft_dataset += np_tmp_dft
        number_of_images += 1
        if np_dft_dataset.max() > 0.95*np.finfo(final_type).max:
            raise ValueError("Dataset to large, datatype overflow")
    return np_dft_dataset / number_of_images


In [None]:
# Mean normalized DFT spectrum of each dataset. Every image is loaded and
# transformed, so this is presumably the slowest cell in the notebook.
np_dft_ffhq = dft_from_dataset(src_path=SOURCE_DIR_FFHQ, grayscale=GRAYSCALE)
np_dft_style1 = dft_from_dataset(src_path=SOURCE_DIR_V1, grayscale=GRAYSCALE)
np_dft_style2 = dft_from_dataset(src_path=SOURCE_DIR_V2, grayscale=GRAYSCALE)

In [None]:
# Mean spectra keyed by the dataset label used in the plot titles.
datasets = dict(
    ffhq=np_dft_ffhq,
    StyleGanv1=np_dft_style1,
    StyleGanv2=np_dft_style2,
)

# Channel labels keyed by the channel index as a string ("0", "1", ...), which is
# how the plotting cells look them up via color_channels[str(col)].
color_channels = {
    str(idx): label
    for idx, label in enumerate(['Gray'] if GRAYSCALE else ['R', 'G', 'B'])
}

# DFT from images

In [None]:
# Mean DFT spectrum of each dataset: one row per dataset, one column per colour channel.
# squeeze=False keeps axs 2-D for every grid size, so axs[row, col] always works.
fig, axs = plt.subplots(nrows=len(datasets), ncols=len(color_channels), figsize=(15,15),
                        squeeze=False)
fig.suptitle("2D-DFT for different dataset")
for row, (name, dataset) in enumerate(datasets.items()):
    for col in range(len(color_channels)):
        # A 2-D spectrum is already a single channel; otherwise select channel `col`.
        channel = dataset if dataset.ndim == 2 else dataset[:, :, col]
        ax = axs[row, col]
        ax.imshow(channel, cmap='gray')
        if len(color_channels) > 1:
            ax.set_title(f"{name}_{color_channels[str(col)]}")
        else:
            ax.set_title(f"{name}")
directory = 'Grayscale' if GRAYSCALE else 'RGB'
filename="DFT_comaparision.png"
save_path = os.path.join(base_save_path, directory, filename)
# savefig does not create missing parent directories.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
fig.savefig(save_path)


# Histograms from DFT

In [None]:
# Density histogram of the mean spectrum values for each dataset and colour channel.
# squeeze=False keeps axs 2-D for every grid size, so axs[row, col] always works.
fig, axs = plt.subplots(nrows=len(datasets), ncols=len(color_channels), sharey=True, figsize=(15,15),
                        squeeze=False)
fig.suptitle("Average histograms from 2D-DFT images")
for row, (name, dataset) in enumerate(datasets.items()):
    for col in range(len(color_channels)):
        # A 2-D spectrum is already a single channel; otherwise select channel `col`.
        channel = dataset if dataset.ndim == 2 else dataset[:, :, col]
        hist, bins = np.histogram(channel, density=True, bins=100)
        # Plot each density at its bin centre; using the left edges would shift
        # the curve by half a bin width.
        bin_centres = (bins[:-1] + bins[1:]) / 2
        ax = axs[row, col]
        ax.plot(bin_centres, hist)
        ax.set_xlabel("mean normalized log-magnitude")
        ax.set_ylabel("density")
        if len(color_channels) > 1:
            ax.set_title(f"{name}_{color_channels[str(col)]}")
        else:
            ax.set_title(f"{name}")
directory = 'Grayscale' if GRAYSCALE else 'RGB'
filename="DFT_histogram_comaparision.png"
save_path = os.path.join(base_save_path, directory, filename)
# savefig does not create missing parent directories.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
fig.savefig(save_path)