In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import math

def animals(root='./animals', image_size=(1080, 1080), batch_size=64, num_preview=6):
    """Load Oxford-IIIT Pet as grayscale tensors and plot a preview of the first training batch.

    Args:
        root: directory passed to ``datasets.OxfordIIITPet`` (``download=True`` is used).
        image_size: (height, width) every image is resized to. The default of 1080x1080
            makes each 64-image batch large (~300 MB as float32); lower it if memory is tight.
        batch_size: number of images per batch in both loaders.
        num_preview: how many images of the first training batch to plot (3 per row).

    Returns:
        (train_loader, test_loader): unshuffled DataLoaders over the "trainval" and "test" splits.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),  # Convert images to PyTorch tensors
    ])

    train_dataset = datasets.OxfordIIITPet(root=root, split="trainval", download=True, transform=transform)
    test_dataset = datasets.OxfordIIITPet(root=root, split="test", download=True, transform=transform)

    # shuffle=False keeps the batch order deterministic, so the first batch is reproducible.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    images, labels = next(iter(train_loader))

    # Never index past the end of a short batch.
    n_show = max(0, min(num_preview, len(images)))
    n_cols = 3
    n_rows = math.ceil(n_show / n_cols)
    for i in range(n_show):
        plt.subplot(n_rows, n_cols, i + 1)
        # (C, H, W) -> (H, W[, C]); a single-channel image needs an explicit gray colormap.
        plt.imshow(images[i].permute(1, 2, 0).squeeze(-1), cmap="gray")
        plt.title(f"Label: {labels[i].item()}")
        plt.axis("off")

    plt.show()
    return train_loader, test_loader

# Build the train/test DataLoaders and plot a preview of the first training batch.
train_loader, test_loader = animals()



In [None]:
import torch.fft as ft

# Pull the first (unshuffled) training batch and pick one image from it.
images, labels = next(iter(train_loader))
image = images[4]  # (C, H, W); C == 1 because of the Grayscale transform

plt.imshow(image.permute(1, 2, 0).squeeze(-1), cmap='gray')
plt.show()

# fft2 already acts on the last two dims; restrict fftshift to them too so the
# channel (or any batch) dim is never rolled.
f_transform = ft.fft2(image)
f_transform_shifted = ft.fftshift(f_transform, dim=(-2, -1))
# Log scale compresses the huge dynamic range between the DC term and the rest.
magnitude_spectrum = torch.log(torch.abs(f_transform_shifted) + 1e-10)

plt.figure(figsize=(6, 6))
plt.imshow(magnitude_spectrum.squeeze())
plt.title('Magnitude Spectrum (Frequency Domain)')
plt.axis('off')
plt.show()


In [None]:
# Create a Low-Pass Filter
def low_pass_filter(shape, cutoff):
    """Build an ideal (circular) low-pass mask for a centred (fftshift-ed) 2-D spectrum.

    Args:
        shape: shape of the spectrum. Only the last two entries (rows, cols) are used, so
            (H, W), (C, H, W) and (N, C, H, W) shapes all work.
        cutoff: radius of the pass band, in frequency bins from the centre.

    Returns:
        float32 tensor of shape (rows, cols): 1.0 where the distance from the centre is
        <= cutoff, 0.0 elsewhere. It broadcasts against the spectrum's trailing two dims.
    """
    rows, cols = shape[-2], shape[-1]
    # fftshift moves the zero-frequency term to index n // 2 along each axis.
    center_row, center_col = rows // 2, cols // 2

    row_idx, col_idx = torch.meshgrid(torch.arange(rows), torch.arange(cols), indexing="ij")
    distance = torch.sqrt((col_idx - center_col) ** 2 + (row_idx - center_row) ** 2)

    return (distance <= cutoff).float()  # 1 inside (and on) the circle, 0 outside
cutoff_frequency = 30  # radius (in frequency bins) of the pass band; adjust as needed
print(f_transform_shifted.shape)
lp_filter = low_pass_filter(f_transform_shifted.shape, cutoff_frequency)

# Apply the low-pass filter: the (H, W) mask broadcasts over the leading channel dim.
filtered_transform = f_transform_shifted * lp_filter
magnitude_spectrum_filtered = torch.log(torch.abs(filtered_transform) + 1e-10)

# Undo the centring on the spatial-frequency dims only, then invert the FFT.
filtered_image_shifted = torch.fft.ifftshift(filtered_transform, dim=(-2, -1))
filtered_image = torch.fft.ifft2(filtered_image_shifted)
# For a real image and a centre-symmetric mask the imaginary part should be negligible.
filtered_image_real = torch.real(filtered_image)

plt.figure(figsize=(18, 6))
plt.subplot(1, 3, 1)
plt.title('Original Image')
plt.imshow(image.permute(1, 2, 0).squeeze(-1), cmap='gray')

plt.subplot(1, 3, 2)
plt.title('filtered transform')
plt.imshow(magnitude_spectrum_filtered.squeeze())

plt.subplot(1, 3, 3)
plt.title('Low-Pass Filtered Image')
plt.imshow(filtered_image_real.squeeze(), cmap='gray')
plt.show()



In [None]:
# Create a High-Pass Filter
def high_pass_filter(shape, cutoff):
    """Build an ideal (circular) high-pass mask for a centred (fftshift-ed) 2-D spectrum.

    Args:
        shape: shape of the spectrum. Only the last two entries (rows, cols) are used, so
            (H, W), (C, H, W) and (N, C, H, W) shapes all work.
        cutoff: radius of the stop band, in frequency bins from the centre.

    Returns:
        float32 tensor of shape (rows, cols): 0.0 where the distance from the centre is
        <= cutoff, 1.0 where it is > cutoff. It broadcasts against the spectrum's trailing two dims.
    """
    rows, cols = shape[-2], shape[-1]
    # fftshift moves the zero-frequency term to index n // 2 along each axis.
    center_row, center_col = rows // 2, cols // 2

    row_idx, col_idx = torch.meshgrid(torch.arange(rows), torch.arange(cols), indexing="ij")
    distance = torch.sqrt((col_idx - center_col) ** 2 + (row_idx - center_row) ** 2)

    # Strict '>' makes this the exact complement of the ideal low-pass mask (distance <= cutoff),
    # so no frequency on the boundary circle is passed by both filters.
    return (distance > cutoff).float()  # 0 inside (and on) the circle, 1 outside
cutoff_frequency = 30  # radius (in frequency bins) of the blocked low-frequency disc; adjust as needed
print(f_transform_shifted.shape)
hp_filter = high_pass_filter(f_transform_shifted.shape, cutoff_frequency)

# Apply the high-pass filter: suppress the coefficients inside the cutoff circle around the centred DC term.
filtered_transform = f_transform_shifted * hp_filter
magnitude_spectrum_filtered = torch.log(torch.abs(filtered_transform) + 1e-10)

# Undo the centring on the spatial-frequency dims only, then invert the FFT.
filtered_image_shifted = torch.fft.ifftshift(filtered_transform, dim=(-2, -1))
filtered_image = torch.fft.ifft2(filtered_image_shifted)
filtered_image_real = torch.real(filtered_image)

plt.figure(figsize=(18, 6))
plt.subplot(1, 3, 1)
plt.title('Original Image')
plt.imshow(image.permute(1, 2, 0).squeeze(-1), cmap='gray')

plt.subplot(1, 3, 2)
plt.title('filtered transform')
plt.imshow(magnitude_spectrum_filtered.squeeze())

plt.subplot(1, 3, 3)
plt.title('High-Pass Filtered Image')
plt.imshow(filtered_image_real.squeeze(), cmap='gray')
plt.show()

In [None]:
# Crude high-pass: zero a square of low frequencies around the centred DC term.
block_half_width = 50  # half the side length of the zeroed square, in frequency bins
rows, cols = f_transform_shifted.shape[-2], f_transform_shifted.shape[-1]
crow, ccol = rows // 2, cols // 2
# Clamp the start so small images don't produce a negative (wrap-around) slice.
r0, r1 = max(crow - block_half_width, 0), crow + block_half_width
c0, c1 = max(ccol - block_half_width, 0), ccol + block_half_width

# Zero a copy: writing into f_transform_shifted would silently change the spectrum the
# compression cell works on. The leading `...` targets the last two (frequency) dims. The
# spectrum here is (C, H, W) with C == 1, so a bare [r0:r1, c0:c1] slice would land on the
# size-1 channel dim, select nothing, and leave the spectrum untouched.
blocked_transform = f_transform_shifted.clone()
blocked_transform[..., r0:r1, c0:c1] = 0

# Inverse FFT shift (spatial-frequency dims only) and transformation back to the spatial domain
f_transform_inverse = torch.fft.ifftshift(blocked_transform, dim=(-2, -1))
img_back = torch.abs(torch.fft.ifft2(f_transform_inverse))

# Normalize to [0, 1] for display; guard the constant-image case to avoid 0/0.
lo, hi = img_back.min(), img_back.max()
img_back = (img_back - lo) / (hi - lo) if hi > lo else torch.zeros_like(img_back)

# Visualize the filtered image
plt.imshow(img_back.squeeze(), cmap='gray', vmin=0, vmax=1)
plt.title('Filtered Image')
plt.xticks([])
plt.yticks([])
plt.show()

In [None]:
# Per-coefficient magnitude of the centred spectrum (strength of each frequency component).
magnitude = f_transform_shifted.abs()

# Function to compress image using Fourier Transform
def compress_image(f_transform_shifted, percentage):
    """Keep only the strongest fraction of Fourier coefficients and reconstruct the image.

    Args:
        f_transform_shifted: centred (fftshift-ed) spectrum; the FFT dims are the last two.
        percentage: fraction in [0, 1] of coefficients to keep (e.g. 0.01 keeps 1%).
            It is clamped to the valid range.

    Returns:
        (compressed_image_real, mask):
            compressed_image_real: real part of the reconstruction, min-max normalized to
                [0, 1], or all zeros when the reconstruction is constant.
            mask: tensor with the same shape and dtype as the input, holding the kept
                coefficients and zeros everywhere else.
    """
    # Rank by this spectrum's own magnitudes (not a global computed elsewhere).
    flat_spectrum = f_transform_shifted.reshape(-1)
    total_coefficients = flat_spectrum.numel()
    num_keep = min(max(int(total_coefficients * percentage), 0), total_coefficients)

    # zeros_like preserves the input dtype (e.g. complex128), and writing into a fresh
    # flat buffer avoids relying on flatten() happening to return a view.
    kept_flat = torch.zeros_like(flat_spectrum)
    if num_keep > 0:
        keep_idx = torch.topk(flat_spectrum.abs(), num_keep).indices
        kept_flat[keep_idx] = flat_spectrum[keep_idx]
    mask = kept_flat.reshape(f_transform_shifted.shape)

    # Undo the centring on the frequency dims only, then invert the FFT.
    compressed_image = torch.fft.ifft2(torch.fft.ifftshift(mask, dim=(-2, -1)))
    compressed_image_real = torch.real(compressed_image)

    # Min-max normalize; a constant image (e.g. nothing kept) would otherwise yield NaN from 0/0.
    lo = compressed_image_real.min()
    span = compressed_image_real.max() - lo
    if span > 0:
        compressed_image_real = (compressed_image_real - lo) / span
    else:
        compressed_image_real = torch.zeros_like(compressed_image_real)

    return compressed_image_real, mask

# Vary the percentage of coefficients used for compression
percentages = [0.001, 0.005, 0.01, 0.015]  # 0.1%, 0.5%, 1%, 1.5% of the coefficients
reconstructed_images = []


def _draw_compression_row(spectrum, percentage, n_rows, row):
    """Compress `spectrum` to `percentage` of its coefficients and draw one row of two panels.

    Left panel: the reconstructed image. Right panel: log-magnitude of the retained
    coefficients. Returns the reconstructed image so the caller can collect it.
    """
    compressed_image_real, kept_coefficients = compress_image(spectrum, percentage)
    # Log scale so the few retained coefficients stand out against the zeroed background.
    log_kept = torch.log(torch.abs(kept_coefficients) + 1e-10)

    plt.subplot(n_rows, 2, 2 * row + 1)
    plt.title(f'Compressed Image ({(percentage * 100)}% Coefficients)')
    plt.imshow(compressed_image_real.squeeze(), cmap='gray')
    plt.axis('off')

    plt.subplot(n_rows, 2, 2 * row + 2)
    plt.title(f'Fourier transform ({(percentage * 100)}% Coefficients)')
    plt.imshow(log_kept.squeeze())
    plt.axis('off')
    return compressed_image_real


# One row of plots per percentage
plt.figure(figsize=(8, 18))
for i, percentage in enumerate(percentages):
    reconstructed_images.append(
        _draw_compression_row(f_transform_shifted, percentage, len(percentages), i)
    )
plt.tight_layout()
plt.show()

# Keeping half of the coefficients, for comparison
plt.figure(figsize=(8, 10))
reconstructed_images.append(_draw_compression_row(f_transform_shifted, 0.5, 1, 0))
plt.tight_layout()
plt.show()