Libraries

In [54]:
## Standard libraries
import os
import math
import numpy as np
import time
import matplotlib.pyplot as plt
import numpy.fft as fft
import cv2
from PIL import Image

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

## Torchvision
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms

## Progress bar
from tqdm.notebook import tqdm


Constants

In [55]:
# Root folder under which the datasets are stored / downloaded
DATASET_PATH = "../data"
# Use the first CUDA GPU when one is available, otherwise run on the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device", device)

Using device cuda:0


Define DFT transform

In [56]:
#helper functions
def get_centered_dft(img):
    """Return the 2-D DFT of ``img`` with the zero-frequency bin shifted to the centre.

    Args:
        img: 2-D array-like accepted by ``np.fft.fft2``.

    Returns:
        Complex ndarray: ``np.fft.fftshift`` applied to ``np.fft.fft2(img)``.
    """
    return np.fft.fftshift(np.fft.fft2(img))

def restore_img_from_centered_dft(centered_dft):
    """Invert ``get_centered_dft``: un-centre the spectrum and apply the inverse 2-D DFT.

    Args:
        centered_dft: 2-D complex array whose zero-frequency bin was moved to
            the centre with ``np.fft.fftshift``.

    Returns:
        Complex ndarray with the reconstructed image. The result of
        ``np.fft.ifft2`` is returned as-is, so callers take ``.real``
        themselves if they want a real image.
    """
    # ifftshift is the exact inverse of fftshift. Applying fftshift a second
    # time only happens to work for even-length axes; on odd-length axes it
    # leaves the spectrum rotated by one bin and the image is reconstructed wrong.
    dft = np.fft.ifftshift(centered_dft)
    return np.fft.ifft2(dft)

def flat_transform(dft):
    """Interleave real and imaginary parts of a 2-D array along the last axis.

    Args:
        dft: 2-D array of shape ``(H, W)`` exposing ``.real`` and ``.imag``.

    Returns:
        float64 ndarray of shape ``(H, 2 * W)`` where even columns hold the
        real parts and odd columns hold the imaginary parts.
    """
    rows, cols = dft.shape
    interleaved = np.empty((rows, 2 * cols), dtype=np.float64)
    # View the output as (rows, cols, 2) so each complex value maps to one
    # adjacent (real, imag) pair; the view shares memory with `interleaved`.
    pairs = interleaved.reshape(rows, cols, 2)
    pairs[..., 0] = dft.real
    pairs[..., 1] = dft.imag
    return interleaved

def inverse_flat_transform(flat_dft):
    """Rebuild a complex array from interleaved real/imaginary columns.

    This is the inverse of ``flat_transform``.

    Args:
        flat_dft: 2-D array-like of shape ``(H, 2 * W)`` with real parts in
            the even columns and imaginary parts in the odd columns.

    Returns:
        complex128 ndarray of shape ``(H, W)``.
    """
    flat = np.asarray(flat_dft)
    real = flat[:, ::2]
    imag = flat[:, 1::2]
    # Fill a complex array directly instead of np.vectorize(complex): that is
    # a per-element Python loop and it raises on size-0 input.
    dft = np.empty(real.shape, dtype=np.complex128)
    dft.real = real
    dft.imag = imag
    return dft
    

def show_centered_dft(centered_dft):
    """Plot the phase and log-magnitude of a centred DFT side by side.

    Both components are min-max stretched to ``[0, 1]`` and drawn as
    grayscale images in a 1x2 matplotlib figure.

    Args:
        centered_dft: 2-D complex array (e.g. the output of ``get_centered_dft``).
    """

    def _stretch(values):
        # Min-max normalise to [0, 1]; a constant input maps to all zeros
        # instead of dividing by zero and producing NaNs.
        lo = np.min(values)
        span = np.max(values) - lo
        if span == 0:
            return np.zeros_like(values, dtype=np.float64)
        return (values - lo) / span

    phase = np.angle(centered_dft)
    magnitude = np.abs(centered_dft)

    # log10(0) is -inf, which would turn the whole stretched image into NaN.
    # Floor exact zeros at the smallest positive magnitude so that non-zero
    # bins are displayed exactly as before and zero bins become the darkest.
    positive = magnitude[magnitude > 0]
    floor = positive.min() if positive.size else 1.0
    log_magnitude = np.log10(np.maximum(magnitude, floor))

    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(_stretch(phase), cmap='gray')
    axs[0].set_title('Phase Component')
    axs[1].imshow(_stretch(log_magnitude), cmap='gray')
    axs[1].set_title('Logarithm of Magnitude Component')
    plt.show()

def _radial_pixel_order(h, w, center=None):
    """Flat (row-major) pixel indices sorted by squared distance to ``center``.

    Ties keep row-major order (stable sort). Only the first ``min(h, w) ** 2``
    indices are returned, so for non-square inputs the farthest pixels are
    left out.
    """
    if center is None:
        center = ((h - 1) // 2, (w - 1) // 2)
    rows, cols = np.indices((h, w))
    dist2 = (rows - center[0]) ** 2 + (cols - center[1]) ** 2
    n = min(h * w, min(h, w) ** 2)
    return np.argsort(dist2.ravel(), kind="stable")[:n]


def center_transform(image, center=None):
    """Re-order pixels so that those closest to ``center`` come first.

    Pixels are ranked by squared Euclidean distance to ``center`` and written
    into the output in row-major order: the closest pixel lands at ``(0, 0)``,
    the next at ``(0, 1)``, and so on. For non-square inputs only the
    ``min(h, w) ** 2`` closest pixels are copied; the rest of the output is 0.

    Args:
        image: ndarray whose first two axes are ``(h, w)``.
        center: optional ``(row, col)`` reference point. Defaults to
            ``((h - 1) // 2, (w - 1) // 2)``. Note that ``np.fft.fftshift``
            places the zero-frequency bin at ``(h // 2, w // 2)``, which
            differs from the default for even sizes; pass it explicitly if
            the ordering should start at the DC bin.

    Returns:
        New ndarray with the same shape and dtype as ``image``.
    """
    image = np.asarray(image)
    h, w = image.shape[:2]
    order = _radial_pixel_order(h, w, center)
    tail = image.shape[2:]
    # np.zeros gives a C-contiguous array, so the reshape below is a view and
    # writes through to `new_image`.
    new_image = np.zeros(image.shape, dtype=image.dtype)
    flat_dst = new_image.reshape((h * w,) + tail)
    flat_src = image.reshape((h * w,) + tail)
    flat_dst[: order.size] = flat_src[order]
    return new_image


def inverse_center_transform(image, center=None):
    """Undo ``center_transform`` by scattering row-major pixels back to their ranked positions.

    Args:
        image: ndarray produced by ``center_transform`` (first two axes ``(h, w)``).
        center: must match the ``center`` used for the forward transform;
            defaults to ``((h - 1) // 2, (w - 1) // 2)``.

    Returns:
        New ndarray with the same shape and dtype as ``image``. For non-square
        inputs, positions that the forward transform dropped stay 0.
    """
    image = np.asarray(image)
    h, w = image.shape[:2]
    order = _radial_pixel_order(h, w, center)
    tail = image.shape[2:]
    new_image = np.zeros(image.shape, dtype=image.dtype)
    flat_dst = new_image.reshape((h * w,) + tail)
    flat_src = image.reshape((h * w,) + tail)
    flat_dst[order] = flat_src[: order.size]
    return new_image

def DFT_transform(img):
    """Convert each channel of an image into its flattened, centre-ordered DFT.

    Every channel goes through ``get_centered_dft`` -> ``center_transform``
    -> ``flat_transform``, which doubles the last dimension (interleaved
    real/imaginary parts).

    Args:
        img: 3-D image indexed as ``img[channel]`` with ``img.shape`` of
            length 3 (presumably a ``(C, H, W)`` torch tensor -- confirm
            against callers).

    Returns:
        ``torch.Tensor`` of shape ``(C, H, 2 * W)`` in torch's default float dtype.
    """
    channels, height, width = img.shape[0], img.shape[1], img.shape[2]
    result = torch.zeros((channels, height, 2 * width))
    for c in range(channels):
        spectrum = flat_transform(center_transform(get_centered_dft(img[c])))
        result[c] = torch.Tensor(spectrum)
    return result

def inverse_DFT_transform(img):
    """Reconstruct an image from the output of ``DFT_transform``.

    Every channel goes through ``inverse_flat_transform`` ->
    ``inverse_center_transform`` -> ``restore_img_from_centered_dft`` and the
    real part of the result is kept.

    Args:
        img: 3-D array/tensor of shape ``(C, H, 2 * W)`` holding interleaved
            real/imaginary DFT coefficients per channel.

    Returns:
        Real-valued ``torch.Tensor`` of shape ``(C, H, W)`` in torch's default
        float dtype.
    """
    # The flattened representation stores two real columns per complex
    # coefficient, so the reconstructed image is half as wide as the input.
    result = torch.zeros((img.shape[0], img.shape[1], img.shape[2] // 2))
    for i in range(result.shape[0]):  # for every channel
        channel = img[i]
        if isinstance(channel, torch.Tensor):
            # The helpers are NumPy-based; hand them a plain CPU array.
            channel = channel.detach().cpu().numpy()
        # Process the single channel, not the whole 3-D `img`.
        spectrum = inverse_flat_transform(channel)
        spectrum = inverse_center_transform(spectrum)
        restored = restore_img_from_centered_dft(spectrum)
        # Take the real part explicitly instead of writing complex values
        # into a real tensor.
        result[i] = torch.from_numpy(np.ascontiguousarray(np.real(restored)))
    return result

Create DFT data

In [57]:
# Convert images from 0-1 to 0-255 (integers). We use the long datatype as we will use the images as labels as well
def discretize(sample):
    """Scale a tensor by 255 and cast it to ``torch.long`` (truncating toward zero)."""
    scaled = sample * 255
    return scaled.long()

# Per-image preprocessing: convert to a tensor, then discretize it
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        discretize,
    ]
)

# MNIST training split (downloaded into DATASET_PATH if missing)
train_dataset = MNIST(
    root=DATASET_PATH,
    train=True,
    transform=transform,
    download=True,
)

# MNIST test split
test_set = MNIST(
    root=DATASET_PATH,
    train=False,
    transform=transform,
    download=True,
)

# Training and test splits concatenated into one dataset
whole_dataset = train_dataset + test_set

# We define a set of data loaders that we can use for various purposes later.
# train_loader = data.DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True, pin_memory=True, num_workers=0)
# test_loader = data.DataLoader(test_set, batch_size=128, shuffle=False, drop_last=False, num_workers=0)

In [58]:
# Running extremes of the DFT coefficients over the whole dataset.
# Both start at 0, so the reported max is never below 0 and the min never above 0.
Origin_max_value = 0
Origin_min_value = 0

for img, _ in tqdm(whole_dataset):
    img_DFT = DFT_transform(img)
    # max()/min() return their first argument on ties, so the stored value
    # only changes on a strict improvement.
    Origin_max_value = max(Origin_max_value, img_DFT.max().item())
    Origin_min_value = min(Origin_min_value, img_DFT.min().item())

print("Original DFT max value: {}".format(Origin_max_value))
print("Original DFT min value: {}".format(Origin_min_value))

  0%|          | 0/70000 [00:00<?, ?it/s]

Original DFT max value: 79483.0
Original DFT min value: -40394.43359375


In [59]:
train_data_path = "../data/MNIST_DFT/train"
test_data_path = "../data/MNIST_DFT/test"


def _export_dft_images(dataset, out_dir, prefix, idx, desc):
    """Write the DFT representation of every sample in ``dataset`` as an image file.

    Each image is transformed with ``DFT_transform``, rescaled to 0..255 with
    the global ``Origin_min_value`` / ``Origin_max_value``, replicated to
    three channels and saved as ``<prefix>_id<idx>_tag<label>.jpg``.

    Args:
        dataset: iterable of ``(image, label)`` pairs.
        out_dir: destination folder; created if it does not exist.
        prefix: file-name prefix (``"train"`` / ``"test"``).
        idx: last id already used; numbering continues at ``idx + 1``.
        desc: progress-bar description.

    Returns:
        The last id that was used.
    """
    # Image.save does not create missing folders.
    os.makedirs(out_dir, exist_ok=True)
    value_range = Origin_max_value - Origin_min_value
    for img, tag in tqdm(dataset, desc=desc):
        idx += 1
        img_DFT = DFT_transform(img)
        scaled = (img_DFT - Origin_min_value) * 256 / value_range
        # A coefficient equal to the global max scales to exactly 256, which
        # does not fit into uint8; clamp so it maps to 255 instead of wrapping.
        img_DFT = scaled.clamp(0, 255).to(torch.uint8)
        img_DFT_rgb = torch.cat([img_DFT] * 3, dim=0)
        hwc_array = np.transpose(img_DFT_rgb.numpy(), (1, 2, 0))
        # NOTE(review): JPEG is lossy, so the stored coefficients are not
        # reproduced exactly on load; a lossless format such as PNG may be
        # preferable. Kept as .jpg so existing file names stay valid.
        img_path = os.path.join(out_dir, "{}_id{}_tag{}.jpg".format(prefix, idx, tag))
        Image.fromarray(hwc_array).save(img_path)
    return idx


# Ids keep counting across both splits (test ids continue after the train ids).
idx = -1
idx = _export_dft_images(train_dataset, train_data_path, "train", idx, "generating train dataset")
idx = _export_dft_images(test_set, test_data_path, "test", idx, "generating test dataset")


generating train dataset:   0%|          | 0/60000 [00:00<?, ?it/s]

generating test dataset:   0%|          | 0/10000 [00:00<?, ?it/s]