In [1]:
import torch
import torch.nn as nn

In [55]:
class PatchEmbed(nn.Module):
    """Split image into patches and then embed them.

    Parameters
    ----------
    img_size : int
        Size of the image (it is a square).

    patch_size : int
        Size of the patch (it is a square).

    in_chans : int
        Number of input channels.

    embed_dim : int
        The embedding dimension.

    Attributes
    ----------
    n_patches : int
        Number of patches inside of our image.

    proj : nn.Conv2d
        Convolutional layer that does both the splitting into patches and their embedding.
    """
    def __init__(self, img_size, patch_size, in_channels = 3, embed_dim = 768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size = patch_size, stride = patch_size,)

    def forward(self, x):
        """ Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape '(n_samples, in_chans, img_size, img_size)'.

        Returns
        --------
        torch.Tensor
            Shape '(n_samples, n_patches, embed_dim)'.
        """
        x = self.proj(x) # (n_samples, emed_dim, n_patches ** 0.5, n_patches ** 0.5)
        print(x.shape)
        x = x.flatten(2) # (n_samples, emed_dim, n_patches)
        print(x.shape)
        x = x.transpose(1,2) # n_samples, n_patches, embed_dim)
        print(x.shape)

        return x

In [63]:
x = torch.randn(1,3,224,224)

model = PatchEmbed(img_size=224,patch_size=16,embed_dim=768,in_channels=3)
model(x).shape

torch.Size([1, 768, 14, 14])
torch.Size([1, 768, 196])
torch.Size([1, 196, 768])


torch.Size([1, 196, 768])

In [72]:
class Attention(nn.Module):
    """ Attention mechanism

    Parameters
    ----------
    dim : int
        The input and out dimension of per token features.

    n_heads: int
        Number of attention heads.

    qkv_bias : bool
        If True, then we include bias to the query, key and value projections.

    attn_p : float
        Dropout probability applied to the query, key and value tensors.

    proj_p : float
        Dropout probability applied to the output tensor.
    """
    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim//n_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim*3, bias = qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        """ Run forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Shape '(n_samples, n_patches +1, dim)'.

        Returns 
        ------
        torch.Tensor
            Shape '(n_samples, n_patches +1, dim)'.
        """
        n_samples, n_tokens, dim = x.shape

        if dim != self.dim:
            raise ValueError

        qkv = self.qkv(x) # (n_samples, n_patches +1, 3 * dim)
        qkv = qkv.reshape(
                n_samples, n_tokens, 3, self, self.n_heads, self.head_dim)
        
        qkv = qkv.permute(
            2, 0, 3, 1, 4)
        # (3, n_samples, n_heads, n_patches + 1,head_dim )
        q, k, v = qkv[0], qkv[1], qkv[2]

        k_t = k.transpose(-2, -1) # (n_samples, n_heads, head_dim, n_patches +1)

        dp = (q @ k_t) * self.scale

        attn = dp.softmax(dim=-1)
        attn = self.attn_drop(attn)
        weighted_avg = attn @ v
        weighted_avg = weighted_avg.transpose(1, 2)
        weighted_avg = weighted_avg.flatten(2)

        x = self.proj(weighted_avg)
        x = self.proj_drop(x)

        return x
    

In [73]:
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, p=0.):
        super().__init()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        return self.fc2(self.drop(self.act(self.fc1(x))))             
    

In [67]:
x = torch.Tensor([1,2,3])
x@x

tensor(14.)

In [57]:
(224/16) **2

196.0

In [58]:
224 * 2

448

In [70]:
x.softmax(dim=-1)

tensor([0.0900, 0.2447, 0.6652])

In [47]:
 x= conv(x)
x.shape

torch.Size([1, 768, 8, 8])

In [20]:
x = x.flatten(2)
x.shape

torch.Size([1, 768, 64])

In [21]:
x = x.transpose(1,2)
x.shape

torch.Size([1, 64, 768])

In [22]:
28*28

784

In [24]:
torch.sqrt(torch.tensor(768))

tensor(27.7128)

In [26]:
p = nn.Dropout2d(0.5)
p(x)



tensor([[[-0.0000, -0.0000,  0.0000,  ...,  0.0000,  0.0000, -0.0000],
         [ 0.0000,  0.0000,  0.0000,  ..., -0.0000,  0.0000,  0.0000],
         [-1.1850,  2.0897, -0.2977,  ...,  0.4573, -0.8970,  1.1202],
         ...,
         [-0.0000, -0.0000, -0.0000,  ..., -0.0000,  0.0000,  0.0000],
         [-0.6121, -0.3459, -1.9492,  ...,  0.4251, -0.6151,  0.2304],
         [-0.0000,  0.0000, -0.0000,  ..., -0.0000,  0.0000, -0.0000]]],
       grad_fn=<MulBackward0>)

https://towardsdatascience.com/position-embeddings-for-vision-transformers-explained-a6f9add341d5

In [61]:
p.training
        

True

In [27]:
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()

        # Create a tensor of shape [max_len, dim] for encoding
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).float().unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, dim, 2).float() * -(math.log(10000.0) / dim))  # [dim/2]

        # Apply sin and cos functions for positional encoding
        pe[:, 0::2] = torch.sin(position * div_term)  # even indices (sine)
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices (cosine)

        pe = pe.unsqueeze(0)  # Shape becomes [1, max_len, dim]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add positional encoding to the input tensor
        return x + self.pe[:, :x.size(1)]  # x.size(1) is the length of the sequence

# Example usage for ViT
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, dim=768, num_classes=1000):
        super(VisionTransformer, self).__init__()
        
        self.patch_size = patch_size
        self.dim = dim
        self.num_classes = num_classes

        # Calculate the number of patches
        self.num_patches = (img_size // patch_size) ** 2

        # Define the embedding layer for patches
        self.patch_embeddings = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)

        # Positional encoding
        self.positional_encoding = PositionalEncoding(dim, max_len=self.num_patches)

        # Transformer layers (simplified version)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=8),
            num_layers=12
        )

        # Classification head
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):
        # Extract patches and embed them
        x = self.patch_embeddings(x)  # Shape: [batch_size, dim, num_patches, num_patches]
        x = x.flatten(2).transpose(1, 2)  # Shape: [batch_size, num_patches, dim]

        # Add positional encoding
        x = self.positional_encoding(x)

        # Pass through transformer encoder
        x = self.encoder(x)

        # Classification head (take the output of the [CLS] token)
        x = x.mean(dim=1)  # Global average pooling

        # Final classification layer
        x = self.fc(x)
        return x

# Example input (batch_size=2, img_size=224)
model = VisionTransformer(img_size=224, patch_size=16)
sample_input = torch.randn(2, 3, 224, 224)  # Batch of 2 images with 3 channels (RGB)
output = model(sample_input)
print(output.shape)  # Should output: torch.Size([2, 1000])




torch.Size([2, 1000])


In [29]:
torch.zeros(5000, 768).shape

torch.Size([5000, 768])

In [32]:
position = torch.arange(0, 5000).float().unsqueeze(1)
position.shape

torch.Size([5000, 1])