In [9]:
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split a square image into non-overlapping square patches and embed each one.

    A single ``Conv2d`` whose kernel size and stride both equal the patch size
    applies the same linear projection to every patch independently.

    Args:
        img_size (int): Side length of the (square) input image.
        patch_size (int): Side length of each (square) patch.
        in_chans (int): Number of input channels (e.g. 3 for RGB images).
        embed_dim (int): Size of each patch embedding.
    """

    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Patches per side is floored, which is what the unpadded strided conv yields.
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed every patch of a batch of images.

        Args:
            x (torch.Tensor): Images of shape ``(batch_size, in_chans, img_size, img_size)``.

        Returns:
            torch.Tensor: Patch embeddings of shape ``(batch_size, n_patches, embed_dim)``.
        """
        # [B, embed_dim, H/p, W/p] -> [B, embed_dim, n_patches] -> [B, n_patches, embed_dim]
        return self.proj(x).flatten(start_dim=2).transpose(1, 2)

# Define the parameters
img_size = 224
patch_size = 16
in_chans = 3
embed_dim = 768

# Seed the RNG so the dummy input and the layer's initial weights are
# reproducible on Restart & Run All.
torch.manual_seed(0)

# Create a dummy input tensor with shape (batch_size, in_chans, img_size, img_size)
dummy_input = torch.randn(1, in_chans, img_size, img_size)

# Instantiate the PatchEmbed class
patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)

# Pass the dummy input through the model
output = patch_embed(dummy_input)

# Inline sanity check: one embedding row per patch, each of size embed_dim.
assert output.shape == (1, patch_embed.n_patches, embed_dim), output.shape

# Print the shape of the output
print("Output shape:", output.shape)


Output shape: torch.Size([1, 196, 768])


In [11]:
import torch
import torch.nn as nn

class Attention(nn.Module):
    """Multi-head self-attention over a sequence of token embeddings.

    Args:
        dim (int): Embedding size of each input token; also the output size.
            Must be divisible by ``n_heads``.
        n_heads (int): Number of attention heads.
        qkv_bias (bool): Whether the fused query/key/value projection has a bias.
        attn_p (float): Dropout probability applied to the attention weights.
        proj_p (float): Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0, proj_p=0):
        super().__init__()
        # Every head gets an equal slice of the embedding. A remainder would make
        # the reshape in forward() fail with an opaque element-count error.
        if dim % n_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by n_heads ({n_heads})")
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim // n_heads
        # Attention logits are scaled by 1/sqrt(head_dim).
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x (torch.Tensor): Input of shape ``(batch_size, n_tokens, dim)``.

        Returns:
            torch.Tensor: Output of shape ``(batch_size, n_tokens, dim)``.
        """
        batch_size, n_tokens = x.shape[:2]

        # Fused projection, then split into q/k/v and heads:
        # [B, T, 3*dim] -> [B, T, 3, n_heads, head_dim] -> [3, B, n_heads, T, head_dim]
        qkv = self.qkv(x).reshape(batch_size, n_tokens, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # [B, n_heads, T, T]; normalize over the key axis.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # [B, n_heads, T, head_dim] -> [B, T, n_heads, head_dim] -> [B, T, dim]
        weighted_avg = (attn @ v).transpose(1, 2).flatten(2)

        x = self.proj(weighted_avg)
        x = self.proj_drop(x)
        return x

# Example usage to test the Attention class
batch_size = 4
n_patches = 16
dim = 768

# Seed the RNG so the dummy input and the module's initial weights are
# reproducible on Restart & Run All.
torch.manual_seed(0)

# Create a dummy input tensor; the +1 token is presumably a prepended class token.
x = torch.randn(batch_size, n_patches + 1, dim)  # Shape: [batch_size, n_patches+1, dim]

# Initialize the Attention module
attention = Attention(dim=dim)

# Pass the input tensor through the Attention module
output = attention(x)

# Inline sanity check: self-attention preserves the input shape.
assert output.shape == x.shape, output.shape

# Print the output shape to verify
print("Output shape:", output.shape)



Output shape: torch.Size([4, 17, 768])
