# In [20]:
import torch
import torch.nn as nn
import torch.fft as fft
from torchvision.models import vit_b_16

class FreqNetViT(nn.Module):
    """ViT-B/16 classifier applied to frequency-filtered versions of its input.

    The forward pass:
      1. removes low spatial frequencies from the image with an FFT-domain
         high-pass mask (``apply_fft_highpass``),
      2. applies learnable 3x3 convolutions to the amplitude and phase spectra
         of the result (``frequency_convolution``), and
      3. classifies the reconstructed spatial-domain image with a pretrained
         torchvision ViT-B/16 whose head is replaced by a ``num_classes``-way
         linear layer.

    Args:
        num_classes: Number of output logits.
        patch_size: Side length of the ``high_pass_filter`` mask (that mask is
            not used by ``forward`` in this code).
        im_width: Expected input width; sets the width of ``high_pass_filter1``.
        im_height: Expected input height; sets the height of ``high_pass_filter1``.

    NOTE(review): torchvision's ViT-B/16 presumably only accepts its native
    224x224 input size, so non-default ``im_width``/``im_height`` may still be
    rejected by ``self.vit`` — confirm before relying on other sizes.
    """

    def __init__(self, num_classes=2, patch_size=16, im_width=224, im_height=224):
        super().__init__()

        # Pre-trained ViT-B/16 backbone.
        # NOTE(review): `pretrained=True` is the legacy torchvision interface
        # (newer releases prefer `weights=...`); kept as-is for compatibility.
        self.vit = vit_b_16(pretrained=True)
        # Replace the classification head with a num_classes-way linear layer.
        self.vit.heads = nn.Linear(self.vit.heads.head.in_features, num_classes)

        self.patch_size = patch_size
        self.im_width = im_width
        self.im_height = im_height

        # Full-image high-pass mask, shape (im_height, im_width), used by
        # apply_fft_highpass. Registered as a non-persistent buffer so that
        # model.to(device) moves it, while the state_dict keys stay unchanged
        # (plain tensor attributes were never saved either).
        self.register_buffer(
            "high_pass_filter1",
            self.create_high_pass_filter(self.im_width, self.im_height),
            persistent=False,
        )
        # Patch-sized mask. Not used by forward() in this code; kept so the
        # attribute remains available.
        self.register_buffer(
            "high_pass_filter",
            self.create_high_pass_filter(self.patch_size),
            persistent=False,
        )

        # Learnable 3x3 convolutions applied to the amplitude and phase spectra.
        self.freq_conv_amp = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1)
        self.freq_conv_phase = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1)

    def forward(self, x):
        """Classify ``x`` after high-pass filtering and spectral convolution.

        Args:
            x: Image batch, e.g. shape (1, 3, 224, 224) in the example usage.

        Returns:
            The logits produced by ``self.vit``.
        """
        # Step 1: keep only the high spatial frequencies (same shape as x).
        x_freq = self.apply_fft_highpass(x)

        # Step 2: learnable convolutions on the amplitude/phase spectra.
        x_freq_convolved = self.frequency_convolution(x_freq)

        # Step 3: classify the result with the Vision Transformer.
        return self.vit(x_freq_convolved)

    def apply_fft_highpass(self, x):
        """Remove low spatial frequencies from ``x`` using ``high_pass_filter1``.

        Args:
            x: Real tensor whose last two dims (height, width) match the shape
                of ``high_pass_filter1``, e.g. (batch, channels, H, W).

        Returns:
            Real tensor with the same shape as ``x``: the real part of the
            inverse FFT of the masked spectrum.
        """
        # 2-D FFT over the spatial dims only.
        x_fft = fft.fftn(x, dim=(-2, -1))

        # Move the zero frequency to the centre of each (H, W) spectrum so the
        # centred block of zeros in the mask removes the low frequencies.
        # The shift is limited to the spatial dims. Without `dim`, fftshift
        # would also roll the batch/channel dims; ifftshift undoes that, but
        # it is wasted work.
        x_fft_shift = fft.fftshift(x_fft, dim=(-2, -1))

        # Zero out the low-frequency block (the mask broadcasts over leading dims).
        x_fft_high = x_fft_shift * self.high_pass_filter1.to(x.device)

        # Undo the shift, then return to the spatial domain. Only the real
        # part is kept because the input is real-valued.
        x_fft_high_shifted = fft.ifftshift(x_fft_high, dim=(-2, -1))
        return torch.real(fft.ifftn(x_fft_high_shifted, dim=(-2, -1)))

    def create_high_pass_filter(self, patch_size, height=None):
        """Build a binary high-pass mask for a centred (fftshift-ed) spectrum.

        The mask is all ones except for a centred rectangle of zeros covering
        rows ``[cy - height//4, cy + height//4)`` and columns
        ``[cx - width//4, cx + width//4)``, where ``cy = height // 2`` and
        ``cx = width // 2``. This is roughly the central half of each side.

        Args:
            patch_size: Mask width. It is also the height when ``height`` is
                omitted.
            height: Optional mask height, for non-square inputs. Defaults to
                ``patch_size``, which gives the original square mask.

        Returns:
            A 0/1 tensor of shape ``(height, patch_size)`` in the default
            floating dtype.
        """
        width = patch_size
        if height is None:
            height = patch_size

        # Start with all ones (pass everything), then block the central region.
        mask = torch.ones(height, width)
        cy, cx = height // 2, width // 2
        mask[cy - height // 4 : cy + height // 4,
             cx - width // 4 : cx + width // 4] = 0

        return mask

    def frequency_convolution(self, x):
        """Apply learnable convolutions to the amplitude and phase spectra of ``x``.

        Args:
            x: Real tensor with 3 channels, e.g. (batch, 3, H, W).

        Returns:
            Real tensor with the same shape as ``x``: the real part of the
            inverse FFT of ``amp' * exp(i * phase')``, where ``amp'`` and
            ``phase'`` are the convolved amplitude and phase.
        """
        # FFT over the spatial dims.
        x_fft = fft.fftn(x, dim=(-2, -1))

        # Split into amplitude and phase spectra.
        amp = torch.abs(x_fft)
        phase = torch.angle(x_fft)

        # Learnable convolutions in the frequency domain.
        amp_conv = self.freq_conv_amp(amp)
        phase_conv = self.freq_conv_phase(phase)

        # Recombine as amp_conv * exp(i * phase_conv). torch.polar is
        # documented as undefined for negative magnitudes, and the conv output
        # can be negative, so the complex tensor is built explicitly.
        x_fft_new = torch.complex(
            amp_conv * torch.cos(phase_conv),
            amp_conv * torch.sin(phase_conv),
        )

        # Inverse FFT back to the spatial domain; keep the real part.
        return torch.real(fft.ifftn(x_fft_new, dim=(-2, -1)))

# Demo: build the model and run one forward pass on random data.
# Guarded so that importing this file as a module does not load the
# pretrained backbone or run inference. In a notebook, __name__ is
# "__main__", so the cell still runs, and the names below stay global.
if __name__ == "__main__":
    # Instantiate the model (loads pretrained ViT-B/16 weights in __init__).
    model = FreqNetViT(num_classes=2)

    # Example input with layout (batch_size, channels, height, width):
    # one 224x224 RGB image.
    input_tensor = torch.randn(1, 3, 224, 224)

    # Forward pass.
    output = model(input_tensor)

    # Output shape: (batch_size, num_classes).
    print("Output shape:", output.shape)



# Cell output:
# x_ifft.shape:  torch.Size([1, 3, 224, 224])   <- likely from an earlier debug print; not produced by the code above
# Output shape: torch.Size([1, 2])
