In [70]:
import torch
import torch.nn as nn
import torch.fft as fft

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention over batch-first inputs.

    The embedding of size ``embed_size`` is split into ``heads`` chunks of
    ``head_dim = embed_size // heads``. Each chunk is passed through the
    ``values`` / ``keys`` / ``queries`` linear layers (one set of weights shared
    by all heads), attention is computed per head, and the concatenated heads
    are mixed by ``fc_out``.
    """

    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        # Per-head projections, applied to every head_dim-sized slice.
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` to ``keys`` and aggregate ``values``.

        Args:
            values: tensor of shape (N, value_len, embed_size).
            keys: tensor of shape (N, key_len, embed_size); key_len must equal value_len.
            query: tensor of shape (N, query_len, embed_size).
            mask: None, or a tensor broadcastable to (N, heads, query_len, key_len);
                positions where ``mask == 0`` receive (effectively) zero attention.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # Apply the learned projections (these layers were previously created but unused).
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k), where d_k is the per-head dimension (Vaswani et al., 2017).
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )

        return self.fc_out(out)

class TransformerBlock(nn.Module):
    """Post-norm transformer encoder layer: self-attention followed by a position-wise MLP.

    Each sub-layer's output is added to its input (residual connection),
    layer-normalised, and then passed through dropout.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        hidden_size = forward_expansion * embed_size

        # Creation order is kept fixed so parameter registration (and seeded init) is stable.
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run attention + MLP sub-layers; the result has the same shape as ``query``."""
        attended = self.dropout(
            self.norm1(self.attention(value, key, query, mask) + query)
        )
        return self.dropout(self.norm2(self.feed_forward(attended) + attended))

class FreqNetSimpleTransformer(nn.Module):
    """Frequency-domain front end followed by a ViT-style transformer classifier.

    Pipeline (see ``forward``):
      1. FFT high-pass filter over the spatial axes of the input image.
      2. Learned 3x3 convolutions on the amplitude and phase spectra, then inverse FFT.
      3. Non-overlapping ``patch_size`` x ``patch_size`` patches -> linear embedding.
      4. Prepend a learned [CLS] token and add a learned positional embedding.
      5. ``num_layers`` TransformerBlocks, fed batch-first (batch, seq, d_model).
      6. Linear head on the [CLS] output.

    Inputs are expected as (batch, 3, im_height, im_width): the frequency
    convolutions and the patch embedding are both built for 3 channels.
    """

    def __init__(self, num_classes=2, patch_size=16, im_width=224, im_height=224, d_model=768, num_heads=8, num_layers=6, dropout=0.1):
        super(FreqNetSimpleTransformer, self).__init__()

        # Parameters
        self.patch_size = patch_size
        self.im_width = im_width
        self.im_height = im_height

        # High-pass filters, registered as non-persistent buffers so they follow
        # model.to(...) like parameters but stay out of the state_dict (as before).
        # forward only uses `high_pass_filter1` (im_height x im_width); the
        # patch-sized `high_pass_filter` is kept for compatibility.
        self.register_buffer(
            "high_pass_filter",
            self.create_high_pass_filter(self.patch_size),
            persistent=False,
        )
        self.register_buffer(
            "high_pass_filter1",
            self.create_high_pass_filter(self.im_height, self.im_width),
            persistent=False,
        )

        # Frequency convolutional layers for amplitude and phase
        self.freq_conv_amp = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1)
        self.freq_conv_phase = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1)

        # Positional embedding is sequence-first: (num_patches + 1, 1, d_model), where
        # the +1 is the [CLS] slot. The patch count uses both image dimensions so
        # non-square images get one embedding per patch.
        num_patches = (im_height // patch_size) * (im_width // patch_size)
        self.pos_embedding = nn.Parameter(torch.randn(num_patches + 1, 1, d_model))
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.patch_to_embedding = nn.Linear(patch_size * patch_size * 3, d_model)

        # Transformer layers
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, dropout=dropout, forward_expansion=4) for _ in range(num_layers)]
        )

        # Final classification head
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: tensor of shape (batch, 3, H, W). H and W should equal the
               constructor's im_height / im_width, because the high-pass filter
               is built for that size.

        Returns:
            Logits of shape (batch, num_classes).
        """
        # Step 1: Convert to the frequency domain, apply the high-pass filter, convert back.
        x_freq = self.apply_fft_highpass(x)

        # Step 2: Learned convolutions on the amplitude / phase spectra.
        x_freq_convolved = self.frequency_convolution(x_freq)

        # Step 3: Patches -> embeddings, sequence-first: (n_patches, batch, d_model).
        x_patches = self.create_patches(x_freq_convolved)
        n_patches, batch_size, _ = x_patches.shape

        # Step 4: Prepend the shared [CLS] token and add positional encoding.
        cls_tokens = self.cls_token.expand(-1, batch_size, -1)  # (1, batch, d_model)
        x_patches = torch.cat((cls_tokens, x_patches), dim=0)   # (n_patches + 1, batch, d_model)
        # pos_embedding has a size-1 batch axis and broadcasts over the batch.
        x_patches = x_patches + self.pos_embedding[:n_patches + 1]

        # Step 5: SelfAttention is batch-first (N, seq, d_model). Feeding the
        # sequence-first tensor directly would make attention mix *samples* across
        # the batch instead of *tokens* within an image, so transpose first.
        tokens = x_patches.transpose(0, 1)  # (batch, n_patches + 1, d_model)
        for transformer_block in self.transformer_blocks:
            tokens = transformer_block(tokens, tokens, tokens, mask=None)

        # Step 6: Classification using the [CLS] token (position 0 of each sequence).
        return self.fc(tokens[:, 0])

    def create_patches(self, x):
        """
        Convert images to non-overlapping patches and embed them for the transformer.

        Args:
            x: tensor of shape (batch, channels, H, W). Rows/columns that do not
               fill a whole patch are dropped by ``unfold``.

        Returns:
            Sequence-first tensor of shape (n_patches, batch, d_model).
        """
        batch_size, channels, height, width = x.shape
        # (batch, C, H // p, W // p, p, p)
        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        # (batch, C, n_patches, p * p)
        patches = patches.contiguous().view(batch_size, channels, -1, self.patch_size * self.patch_size)
        # (n_patches, batch, C * p * p): each patch vector is channel-major.
        patches = patches.permute(2, 0, 1, 3).contiguous().view(-1, batch_size, self.patch_size * self.patch_size * channels)
        return self.patch_to_embedding(patches)

    def apply_fft_highpass(self, x):
        """
        Convert images to the frequency domain, apply the high-pass filter, and convert back.

        Only the last two (spatial) axes are transformed and shifted. Returns the
        real part of the inverse FFT, with the same shape as ``x``.
        """
        spatial = (-2, -1)
        x_fft = fft.fftn(x, dim=spatial)
        # Move the zero frequency to the centre of the spatial axes only.
        x_fft_shift = fft.fftshift(x_fft, dim=spatial)

        # Zero the central (low-frequency) block of the shifted spectrum.
        x_fft_high = x_fft_shift * self.high_pass_filter1.to(x.device)

        # Undo the shift and return to the spatial domain.
        x_fft_high_shifted = fft.ifftshift(x_fft_high, dim=spatial)
        return torch.real(fft.ifftn(x_fft_high_shifted, dim=spatial))

    def create_high_pass_filter(self, patch_size, width=None):
        """
        Build a (patch_size, width) mask of ones with a centred block of zeros.

        When multiplied with an fftshift-ed spectrum, the zero block removes the
        lowest frequencies. With h = rows and w = columns, the block spans rows
        [h//2 - h//4, h//2 + h//4) and columns [w//2 - w//4, w//2 + w//4).

        Args:
            patch_size: number of rows (name kept for backward compatibility).
            width: number of columns; defaults to ``patch_size`` (square filter).

        Returns:
            Float tensor of shape (patch_size, width).
        """
        height = patch_size
        width = patch_size if width is None else width
        filter_mask = torch.ones(height, width)
        center_row, center_col = height // 2, width // 2
        filter_mask[center_row - height // 4:center_row + height // 4,
                    center_col - width // 4:center_col + width // 4] = 0  # zero low frequencies
        return filter_mask

    def frequency_convolution(self, x):
        """
        Apply learned convolutions to the amplitude and phase spectra, then invert the FFT.

        Returns the real part of the inverse FFT, shape (batch, 3, H, W) given 3-channel input.
        """
        x_fft = fft.fftn(x, dim=(-2, -1))  # FFT on spatial dimensions (height, width)

        amp = torch.abs(x_fft)  # Amplitude spectrum
        phase = torch.angle(x_fft)  # Phase spectrum

        # Apply convolutions in the frequency space
        amp_conv = self.freq_conv_amp(amp)
        phase_conv = self.freq_conv_phase(phase)

        # NOTE(review): amp_conv is an unconstrained conv output and may be negative;
        # torch.polar then yields -|amp| * exp(i * phase). Confirm this is intended.
        x_fft_new = torch.polar(amp_conv, phase_conv)

        # Inverse FFT: Convert back to spatial domain
        return torch.real(fft.ifftn(x_fft_new, dim=(-2, -1)))

# Example usage
if __name__ == "__main__":
    # Fixed seed so the random weights and the example input are reproducible.
    torch.manual_seed(0)

    # Example input tensor: batch of images with 3 channels, 224x224 dimensions
    input_tensor = torch.randn(1, 3, 224, 224)

    # Initialize model
    model = FreqNetSimpleTransformer(num_classes=2, patch_size=16, im_width=224, im_height=224, d_model=768, num_heads=8, num_layers=6, dropout=0.1)
    # Inference demo: disable dropout.
    model.eval()

    # Forward pass without building an autograd graph
    with torch.no_grad():
        output = model(input_tensor)

    # Output shape
    print("Output shape:", output.shape)

Output shape: torch.Size([1, 2])


In [5]:
# Unfold only the height axis (dim 2) into non-overlapping windows of `patch_size` rows.
x = torch.randn((1, 3, 224, 224))
patch_size = 16
h = x.unfold(dimension=2, size=patch_size, step=patch_size)

In [7]:
# The height axis becomes 224 // 16 = 14 windows; each window's 16 rows form a new trailing axis.
h.shape

torch.Size([1, 3, 14, 224, 16])

In [8]:
import torch

# Small, readable tensor (values 1..192) to see exactly which elements unfold groups together.
patch_size = 4
x = torch.arange(1, 3 * 8 * 8 + 1).reshape(1, 3, 8, 8)

# Non-overlapping windows of 4 rows along dim 2 -> shape (1, 3, 2, 8, 4)
h = x.unfold(dimension=2, size=patch_size, step=patch_size)

# Show both as NumPy arrays
x_values, h_values = x.numpy(), h.numpy()

x_values, h_values

(array([[[[  1,   2,   3,   4,   5,   6,   7,   8],
          [  9,  10,  11,  12,  13,  14,  15,  16],
          [ 17,  18,  19,  20,  21,  22,  23,  24],
          [ 25,  26,  27,  28,  29,  30,  31,  32],
          [ 33,  34,  35,  36,  37,  38,  39,  40],
          [ 41,  42,  43,  44,  45,  46,  47,  48],
          [ 49,  50,  51,  52,  53,  54,  55,  56],
          [ 57,  58,  59,  60,  61,  62,  63,  64]],
 
         [[ 65,  66,  67,  68,  69,  70,  71,  72],
          [ 73,  74,  75,  76,  77,  78,  79,  80],
          [ 81,  82,  83,  84,  85,  86,  87,  88],
          [ 89,  90,  91,  92,  93,  94,  95,  96],
          [ 97,  98,  99, 100, 101, 102, 103, 104],
          [105, 106, 107, 108, 109, 110, 111, 112],
          [113, 114, 115, 116, 117, 118, 119, 120],
          [121, 122, 123, 124, 125, 126, 127, 128]],
 
         [[129, 130, 131, 132, 133, 134, 135, 136],
          [137, 138, 139, 140, 141, 142, 143, 144],
          [145, 146, 147, 148, 149, 150, 151, 152],
      

In [58]:
# Dedicated seeded generator: the demo tensors are reproducible on re-run
# without disturbing the global RNG state used elsewhere in the notebook.
rng = torch.Generator().manual_seed(0)

# Create integer tensor 'a' of shape (2, 3, 4)
a = torch.randint(0, 10, (2, 3, 4), dtype=torch.int, generator=rng)

# Create integer tensor 'b' of shape (2, 1, 4); its size-1 middle axis broadcasts against a's size-3 axis
b = torch.randint(0, 10, (2, 1, 4), dtype=torch.int, generator=rng)

In [59]:
# Inspect `a` (shape (2, 3, 4)).
a

tensor([[[0, 7, 1, 3],
         [5, 9, 7, 4],
         [3, 4, 2, 8]],

        [[5, 6, 9, 4],
         [8, 0, 3, 9],
         [2, 7, 0, 2]]], dtype=torch.int32)

In [60]:
# Inspect `b` (shape (2, 1, 4)).
b

tensor([[[2, 9, 0, 2]],

        [[7, 0, 0, 8]]], dtype=torch.int32)

In [61]:
# Broadcast add: b's size-1 axis 1 is repeated to match a's size 3, so b[i, 0] is added to every row of a[i].
a+b

tensor([[[ 2, 16,  1,  5],
         [ 7, 18,  7,  6],
         [ 5, 13,  2, 10]],

        [[12,  6,  9, 12],
         [15,  0,  3, 17],
         [ 9,  7,  0, 10]]], dtype=torch.int32)

In [62]:
# The broadcast result takes a's shape: (2, 3, 4).
(a+b).shape

torch.Size([2, 3, 4])

In [63]:
# Slicing dim 0 with a stop equal to its length (2) keeps the whole axis, so the shape is unchanged.
a[:2,:,:].shape

torch.Size([2, 3, 4])