In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import os
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from torch.utils.data import random_split
from PIL import Image 
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
import torch
import matplotlib.pyplot as plt
import cv2
import numpy as np
from typing import Tuple
import torch.fft
from torch import Tensor

import torch
from torch import Tensor
from torch.distributions import Uniform
from torch.distributions.bernoulli import Bernoulli

class GaussianMixtureMask:
    """Applies a Gaussian Mixture Mask in the Fourier domain to an image.

    The mask starts as all ones and is multiplied by ``1 - g`` for every
    randomly placed Gaussian ``g``, so frequencies close to a Gaussian center
    are attenuated while frequencies far from every center are left untouched.

    Attributes:
        num_gaussians: Number of Gaussian kernels to generate in the mixture mask.
        std_range: Tuple containing the minimum and maximum standard deviation for the Gaussians.
    """

    def __init__(
        self, num_gaussians: int = 20, std_range: Tuple[float, float] = (10, 15)
    ):
        """Initializes GaussianMixtureMask with the given parameters.

        Args:
            num_gaussians: Number of Gaussian kernels to generate in the mixture mask.
            std_range: Tuple containing the minimum and maximum standard deviation for the Gaussians.
        """
        self.num_gaussians = num_gaussians
        self.std_range = std_range

    def gaussian_kernel(
        self, size: Tuple[int, int], sigma: Tensor, center: Tensor
    ) -> Tensor:
        """Generates a 2D Gaussian kernel.

        Args:
            size: Tuple specifying the dimensions of the Gaussian kernel (H, W).
            sigma: Tensor with two entries, the standard deviation along each axis.
            center: Tensor with two entries, the (row, column) center of the kernel.

        Returns:
            A 2D Gaussian kernel tensor of shape ``size`` on ``sigma``'s device,
            with value 1 at ``center``.
        """
        # Build the index grid directly on the target device instead of
        # creating it on the CPU and copying it over afterwards.
        rows = torch.arange(0, size[0], device=sigma.device)
        cols = torch.arange(0, size[1], device=sigma.device)
        # "ij" is the behaviour the legacy default provided; passing it
        # explicitly avoids the torch.meshgrid deprecation warning.
        u, v = torch.meshgrid(rows, cols, indexing="ij")
        u0, v0 = center
        gaussian = torch.exp(
            -((u - u0) ** 2 / (2 * sigma[0] ** 2) + (v - v0) ** 2 / (2 * sigma[1] ** 2))
        )

        return gaussian

    def apply_gaussian_mixture_mask(
        self, freq_image: Tensor, num_gaussians: int, std: Tuple[float, float]
    ) -> Tensor:
        """Applies the Gaussian mixture mask to a frequency-domain image.

        Args:
            freq_image: Tensor representing the frequency-domain image of shape (C, H, W//2+1).
            num_gaussians: Number of Gaussian kernels to generate in the mask.
            std: Tuple specifying the standard deviation range for the Gaussians.

        Returns:
            Image tensor in frequency domain after applying the Gaussian mixture mask.
            The input tensor is not modified.
        """
        _channels, U, V = freq_image.shape
        mask = freq_image.new_ones(freq_image.shape)

        for _ in range(num_gaussians):
            u0 = torch.randint(0, U, (1,), device=freq_image.device)
            v0 = torch.randint(0, V, (1,), device=freq_image.device)
            # Concatenate on-device rather than rebuilding a tensor from
            # tensors, which would round-trip through Python scalars.
            center = torch.cat((u0, v0))
            sigma = torch.rand(2, device=freq_image.device) * (std[1] - std[0]) + std[0]

            g_kernel = self.gaussian_kernel((U, V), sigma, center)
            # The same (H, W) kernel is broadcast over every channel.
            mask *= 1 - g_kernel.unsqueeze(0)

        filtered_freq_image = freq_image * mask
        return filtered_freq_image

    def __call__(self, freq_image: Tensor) -> Tensor:
        """Applies the Gaussian mixture mask transformation to the input frequency-domain image.

        Args:
            freq_image: Tensor representing a frequency-domain image of shape (C, H, W//2+1).

        Returns:
            Image tensor in frequency domain after applying the Gaussian mixture mask.
        """
        return self.apply_gaussian_mixture_mask(
            freq_image, self.num_gaussians, self.std_range
        )



class IRFFT2DTransform:
    """Inverse real 2D FFT: half-spectrum back to a spatial-domain image.

    Thin callable wrapper around ``torch.fft.irfft2`` that remembers the
    spatial size to reconstruct.

    Input:
        - Complex tensor of shape (C, H, W//2+1), as produced by ``torch.fft.rfft2``.

    Output:
        - Real tensor of shape (C, *shape), where ``shape`` is the (H, W)
          given at construction time.
    """

    def __init__(self, shape: Tuple[int, int]):
        """
        Args:
            shape: Spatial size (H, W) of the reconstructed image; forwarded
                to ``torch.fft.irfft2`` as ``s``.
        """
        self.shape = shape

    def __call__(self, freq_image: Tensor) -> Tensor:
        """Reconstruct the spatial-domain image from ``freq_image``.

        Args:
            freq_image: Frequency-domain tensor whose last two dimensions
                hold the half-spectrum.

        Returns:
            Tensor: The inverse transform, with trailing dimensions ``self.shape``.
        """
        return torch.fft.irfft2(freq_image, s=self.shape)



class RFFT2DTransform:
    """Real 2D FFT: spatial-domain image to its half-spectrum.

    Thin callable wrapper around ``torch.fft.rfft2``, which transforms the
    last two dimensions of its input.

    Input:
        - Real tensor of shape (C, H, W), where C is the number of channels.

    Output:
        - Complex tensor of shape (C, H, W//2+1); only the non-redundant half
          of the last axis is returned for real input.
    """

    def __call__(self, image: Tensor) -> Tensor:
        """Transform ``image`` into the frequency domain.

        Args:
            image: Input image as a Tensor of shape (C, H, W).

        Returns:
            Tensor: Complex half-spectrum of shape (C, H, W//2+1).
        """
        return torch.fft.rfft2(image)

from typing import Tuple

class AmplitudeRescaleTransform:
    """Implementation of amplitude rescaling transformation.

    This transform rescales the amplitude of the Fourier spectrum
    (`freq_image`) of the image by an independent random factor per frequency
    and returns the result. The phase is left unchanged and the same factors
    are shared by every channel.

    Attributes:
        dist:
            Uniform distribution in `[m, n)` from which the scaling values are drawn.
    """

    def __init__(self, range: Tuple[float, float] = (0.8, 1.75)) -> None:
        """
        Args:
            range: (low, high) bounds of the uniform distribution of scale factors.
        """
        self.dist = Uniform(range[0], range[1])

    def __call__(self, freq_image: Tensor) -> Tensor:
        """Rescale the amplitude of ``freq_image``.

        Args:
            freq_image: Frequency-domain tensor of shape (C, H, W//2+1).

        Returns:
            A new tensor of the same shape; the input is not modified.
        """
        # One factor per frequency bin, shape freq_image.shape[1:].
        scale = self.dist.sample(freq_image.shape[1:]).to(freq_image.device)
        # Multiplying a complex value by a positive real factor scales its
        # amplitude and keeps its phase, so there is no need to decompose the
        # spectrum into amplitude/phase and rebuild it with cos/sin.
        # unsqueeze(0) adds the channel dimension for broadcasting.
        return freq_image * scale.unsqueeze(0)
    

class RandomFrequencyMaskTransform:
    """2D Random Frequency Mask Transformation.

    Zeroes out a random subset of frequencies of a Fourier spectrum. The
    fraction to drop is drawn uniformly from the range ``k`` on every call,
    and the same binary mask is used for all channels.

    Input
        - Tensor: RFFT of a 2D Image (C, H, W) C-> No. of Channels
    Output
        - Tensor: The masked RFFT of the image

    """

    def __init__(self, k: Tuple[float, float] = (0.01, 0.1)) -> None:
        """
        Args:
            k: (low, high) bounds for the per-call masking probability.
        """
        self.k = k

    def __call__(self, fft_image: Tensor) -> Tensor:
        """Return ``fft_image`` with a random subset of frequencies set to zero."""
        drop_probability = np.random.uniform(low=self.k[0], high=self.k[1])

        # Boolean keep-mask over the trailing (non-channel) dimensions: an
        # entry survives when its uniform draw exceeds the drop probability.
        noise = torch.rand(fft_image.shape[1:], device=fft_image.device)
        keep = noise > drop_probability

        # Always keep the zero-frequency mode so the majority of the semantic
        # information is retained. See https://arxiv.org/abs/2312.02205
        keep[0, 0] = True

        # Index with None to add the channel dimension, so every channel has
        # the same frequencies switched off.
        return fft_image * keep[None]
    

class PhaseShiftTransform:
    """Implementation of phase shifting transformation.

    Adds a single random phase offset `theta` to every entry of the Fourier
    spectrum (`freq_image`) of the image and returns the shifted spectrum.
    Amplitudes are preserved.

    Attributes:
        dist:
            A uniform distribution in the range `[p, q)` from which the magnitude of the
            phase shift `theta` is selected.
        include_negatives:
            Whether `theta` may be negated. If `True`, the sign of each sampled
            shift is chosen at random.
        sign_dist:
            A Bernoulli distribution with probability `sign_probability`; a draw of 1
            keeps `theta` positive, anything else negates it. Only defined when
            `include_negatives` is `True`.
    """

    def __init__(
        self,
        range: Tuple[float, float] = (0.4, 0.7),
        include_negatives: bool = False,
        sign_probability: float = 0.5,
    ) -> None:
        self.include_negatives = include_negatives
        self.dist = Uniform(range[0], range[1])
        if include_negatives:
            self.sign_dist = Bernoulli(sign_probability)

    def __call__(self, freq_image: Tensor) -> Tensor:
        """Return ``freq_image`` with its phase shifted by a random `theta`."""
        re_part, im_part = freq_image.real, freq_image.imag

        # Polar decomposition of the spectrum.
        magnitude = torch.sqrt(re_part**2 + im_part**2)
        angle = torch.atan2(im_part, re_part)

        # One scalar shift shared by the whole spectrum.
        shift = self.dist.sample().to(freq_image.device)

        if self.include_negatives:
            # A Bernoulli draw of 1 keeps +θ; otherwise the shift becomes -θ.
            keep_positive = self.sign_dist.sample().to(freq_image.device) == 1
            shift = torch.where(keep_positive, shift, -shift)

        rotated = angle + shift

        # Back to Cartesian form with the rotated phase.
        return torch.complex(
            magnitude * torch.cos(rotated), magnitude * torch.sin(rotated)
        )

def _apply_in_frequency_domain(image, spectrum_transform):
    """Apply ``spectrum_transform`` to the real 2D FFT of ``image`` and invert it.

    Args:
        image: Tensor or array-like of shape (H, W) or (C, H, W).
        spectrum_transform: Callable that maps a (C, H, W//2+1) complex
            spectrum to a spectrum of the same shape.

    Returns:
        Float32 tensor on the module-level ``device`` with the same shape as
        ``image``.
    """
    # as_tensor accepts tensors without the copy-construct warning that
    # torch.tensor(tensor) emits, and still converts lists / arrays.
    image_tensor = torch.as_tensor(image, dtype=torch.float32, device=device)

    # The spectrum transforms expect a leading channel dimension; add one for
    # plain (H, W) input and remove exactly that dimension again at the end.
    added_channel_dim = image_tensor.dim() == 2
    if added_channel_dim:
        image_tensor = image_tensor.unsqueeze(0)

    # Remember the true spatial size. It cannot be recovered from the
    # half-spectrum alone: widths 2k and 2k+1 both yield k+1 frequency columns,
    # so reconstructing it as 2 * (Wf - 1) drops a column for odd widths.
    spatial_shape = tuple(image_tensor.shape[-2:])

    freq_image = RFFT2DTransform()(image_tensor)
    transformed_freq_image = spectrum_transform(freq_image)
    transformed_image = IRFFT2DTransform(spatial_shape)(transformed_freq_image)

    if added_channel_dim:
        transformed_image = transformed_image.squeeze(0)
    return transformed_image


def GaussianMixture(image,num_gaussians=15):
    """Attenuate random frequency regions of ``image`` with a Gaussian mixture mask.

    Args:
        image: Tensor or array-like of shape (H, W) or (C, H, W).
        num_gaussians: Number of Gaussian kernels in the mask.

    Returns:
        Augmented image tensor with the same shape as ``image``.
    """
    gaussian_mixture = GaussianMixtureMask(num_gaussians=num_gaussians, std_range=(10, 15))
    return _apply_in_frequency_domain(image, gaussian_mixture)


def AmplitudeRescale(image):
    """Randomly rescale the amplitude of every frequency of ``image``.

    Args:
        image: Tensor or array-like of shape (H, W) or (C, H, W).

    Returns:
        Augmented image tensor with the same shape as ``image``.
    """
    apl_rescale = AmplitudeRescaleTransform(range=(0.8, 1.75))
    return _apply_in_frequency_domain(image, apl_rescale)


def RandomFrequencyMask(image):
    """Zero out a random subset of the frequencies of ``image``.

    Args:
        image: Tensor or array-like of shape (H, W) or (C, H, W).

    Returns:
        Augmented image tensor with the same shape as ``image``.
    """
    rand_freq = RandomFrequencyMaskTransform(k=(0.01, 0.1))
    return _apply_in_frequency_domain(image, rand_freq)


def PhaseShift(image):
    """Shift the phase of the spectrum of ``image`` by a random signed angle.

    Args:
        image: Tensor or array-like of shape (H, W) or (C, H, W).

    Returns:
        Augmented image tensor with the same shape as ``image``.
    """
    phase_shift = PhaseShiftTransform(range=(0.4, 0.7), include_negatives=True, sign_probability=0.5)
    return _apply_in_frequency_domain(image, phase_shift)

In [3]:
import os
from PIL import Image
from torch.utils.data import Dataset

class AnimalDataset(Dataset):
    """Image-classification dataset read from a folder-per-class directory tree.

    Attributes:
        root_dir: Directory that was scanned.
        transform: Optional callable applied to each loaded PIL image.
        augmentation_fn: The callable passed at construction time (or None).
        class_to_idx: Mapping from class (sub-folder) name to integer label.
        samples: List of (image_path, label) pairs.
    """

    # File extensions (lower-case) that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None,augmentation_fn=None):
        """
        Args:
          root_dir (str): Directory with subfolders per class.
          transform (callable, optional): Transform to apply to PIL images.
          augmentation_fn (callable, optional): Stored as ``self.augmentation_fn``.
            NOTE(review): ``__getitem__`` does not invoke it, which matches the
            previous behaviour where the argument was silently dropped; confirm
            whether it should be applied (and to which split) before wiring it in.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.augmentation_fn = augmentation_fn

        # Only sub-directories are classes. Filter BEFORE enumerating so that
        # stray files in root_dir cannot leave gaps in the label indices;
        # labels are then exactly 0 .. len(class_to_idx) - 1.
        class_names = sorted(
            name
            for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )
        self.class_to_idx = {name: idx for idx, name in enumerate(class_names)}

        # Build list of (image_path, label) pairs
        self.samples = []
        for class_name, idx in self.class_to_idx.items():
            class_folder = os.path.join(root_dir, class_name)
            # os.listdir order is arbitrary; sort so sample indices are
            # reproducible across runs and machines.
            for fname in sorted(os.listdir(class_folder)):
                if fname.lower().endswith(self.IMAGE_EXTENSIONS):
                    path = os.path.join(class_folder, fname)
                    self.samples.append((path, idx))

    def __len__(self):
        """Return the number of images found."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load the image at ``index`` as RGB and return ``(image, label)``.

        ``image`` is the output of ``self.transform`` when one was given,
        otherwise the RGB PIL image.
        """
        path, label = self.samples[index]
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the conversion is done.
        with Image.open(path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label


In [4]:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, random_split

import random

def my_fourier_augmentation(image):
    """Apply one randomly chosen Fourier-domain augmentation to ``image``.

    Args:
        image: Input image, forwarded as-is to the selected augmentation
            function.

    Returns:
        The augmented image returned by the selected augmentation function.
    """
    # Named `augmentations` (not `transforms`) so the local list does not
    # shadow the torchvision `transforms` module used elsewhere in the file.
    augmentations = [
        GaussianMixture,
        AmplitudeRescale,
        RandomFrequencyMask,
        PhaseShift,
    ]
    # Return the augmented result. Previously it was bound to a misspelled
    # name (`lmage`) and the untouched input was returned instead.
    return random.choice(augmentations)(image)

# 1) Preprocessing: resize to 224x224, convert to a tensor, then normalize
#    each channel with the ImageNet mean / std.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# 2) Build the dataset over every class folder.
full_dataset = AnimalDataset(
    root_dir='/kaggle/input/animals10/raw-img',
    transform=data_transform,
    augmentation_fn=my_fourier_augmentation,
)

# 3) Random 80% / 20% split into train and test subsets.
#    (Since PyTorch v1.13, random_split also accepts fractions: [0.8, 0.2].)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_ds, test_ds = random_split(full_dataset, [train_size, test_size])

# 4) Batched loaders; only the training loader shuffles.
loader_options = dict(batch_size=32, num_workers=4)
train_loader = DataLoader(train_ds, shuffle=True, **loader_options)
test_loader = DataLoader(test_ds, shuffle=False, **loader_options)


In [5]:
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pretrained ResNet-18
model = models.resnet18(pretrained=True)

# Freeze the pretrained backbone. The optimizer below only updates `model.fc`,
# so backbone gradients would be computed and stored on every step without
# ever being used. Turning them off leaves the parameter updates unchanged
# while saving the backward pass through the backbone.
for param in model.parameters():
    param.requires_grad = False

# Replace the final FC layer with a fresh head sized for this dataset. The new
# layer is created after the freeze, so its parameters remain trainable.
num_classes = len(full_dataset.class_to_idx)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
# Only the classification head is trained.
optimizer = optim.Adam(model.fc.parameters(), lr=1e-3)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 179MB/s] 


In [6]:
from tqdm import tqdm

# Number of full passes over the training split.
num_epochs = 10

# Each epoch: one optimization pass over train_loader, followed by an accuracy
# measurement on test_loader.
for epoch in range(num_epochs):
    model.train()
    # Accumulated over the epoch: summed per-sample loss and number of correct
    # predictions.
    running_loss, running_corrects = 0.0, 0

    # Wrap the training loader with tqdm (leave=False clears the bar once the
    # epoch finishes).
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # `criterion` is constructed with default arguments, i.e. it returns
        # the batch mean; multiply by the batch size so that dividing by
        # train_size below yields a per-sample average even when the last
        # batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        # Predicted class = index of the largest logit.
        preds = outputs.argmax(dim=1)
        running_corrects += (preds == labels).sum().item()

    # Per-sample averages over the whole training split.
    epoch_loss = running_loss / train_size
    epoch_acc = running_corrects / train_size * 100
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.2f}%")

    # Evaluation
    model.eval()
    test_corrects, test_total = 0, 0
    # No gradients are needed while measuring accuracy.
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            preds = outputs.argmax(dim=1)
            test_corrects += (preds == labels).sum().item()
            test_total += labels.size(0)

    test_acc = test_corrects / test_total * 100
    print(f"          Test Acc: {test_acc:.2f}%\n")


                                                             

Epoch 1/10 - Train Loss: 0.3666, Train Acc: 90.15%




          Test Acc: 94.65%



                                                             

Epoch 2/10 - Train Loss: 0.1912, Train Acc: 94.09%




          Test Acc: 94.71%



                                                             

Epoch 3/10 - Train Loss: 0.1695, Train Acc: 94.69%




          Test Acc: 95.30%



                                                             

Epoch 4/10 - Train Loss: 0.1536, Train Acc: 95.29%




          Test Acc: 94.92%



                                                             

Epoch 5/10 - Train Loss: 0.1495, Train Acc: 95.28%




          Test Acc: 94.84%



                                                             

Epoch 6/10 - Train Loss: 0.1456, Train Acc: 95.23%




          Test Acc: 95.47%



                                                             

Epoch 7/10 - Train Loss: 0.1404, Train Acc: 95.57%




          Test Acc: 94.86%



                                                             

Epoch 8/10 - Train Loss: 0.1394, Train Acc: 95.65%




          Test Acc: 95.02%



                                                             

Epoch 9/10 - Train Loss: 0.1313, Train Acc: 95.69%




          Test Acc: 95.19%



                                                              

Epoch 10/10 - Train Loss: 0.1313, Train Acc: 95.84%




          Test Acc: 95.53%

