In [1]:
import torch
import torch.nn as nn

import math

from typing import Optional, Union, List, Dict, Tuple, Set, Any

In [2]:
class ViTPatchEmbeddings(nn.Module):
    """
    Turn `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.

    Args:
        image_size: Expected input resolution, either one int (square image) or a `(height, width)` pair.
        patch_size: Patch resolution, either one int (square patch) or a `(height, width)` pair.
        num_channels: Number of channels the input images must have.
        hidden_size: Width of every produced patch embedding.

    The defaults (224, 16, 3, 768) reproduce the configuration that used to be hard-coded, so calling
    `ViTPatchEmbeddings()` with no arguments builds exactly the same module as before.
    """

    def __init__(
        self,
        image_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        num_channels: int = 3,
        hidden_size: int = 768,
    ) -> None:
        super().__init__()
        image_size = self._to_pair(image_size)
        patch_size = self._to_pair(patch_size)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        # Number of whole patches that fit along the width times the number that fit along the height.
        self.num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

        # A convolution whose kernel and stride both equal the patch size projects every
        # non-overlapping patch to one `hidden_size`-dimensional vector.
        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    @staticmethod
    def _to_pair(value: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
        """Expand a single int to `(value, value)`; pass a two-element sequence through as a tuple."""
        if isinstance(value, int):
            return (value, value)
        return tuple(value)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """
        Project an image batch to a sequence of patch embeddings.

        Args:
            pixel_values: Tensor of shape `(batch_size, num_channels, height, width)`.
            interpolate_pos_encoding: When `True`, the check that the input resolution equals
                `self.image_size` is skipped.

        Returns:
            Tensor of shape `(batch_size, num_patches, hidden_size)`.

        Raises:
            ValueError: If the channel count differs from `self.num_channels`, or if the resolution differs
                from `self.image_size` while `interpolate_pos_encoding` is `False`.
        """
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        if not interpolate_pos_encoding:
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )
        # projection: [batch_size, hidden_size, height // patch_h, width // patch_w]
        # flatten(2): [batch_size, hidden_size, num_patches]
        # transpose:  [batch_size, num_patches, hidden_size]
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings

In [3]:
# Smoke test: one 224x224 RGB image should become (224 // 16) ** 2 = 196 patch tokens of width 768.
test_embedding = ViTPatchEmbeddings()
pixel_values = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    embeddings = test_embedding(pixel_values)
# Assert the invariant so a fresh "Restart & Run All" fails loudly on a regression
# instead of relying on someone reading the printed shape.
assert embeddings.shape == (1, 196, 768), embeddings.shape
print(embeddings.shape)

torch.Size([1, 196, 768])


In [3]:
class ViTSelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention over a 768-wide hidden state split into 12 heads of 64.

    `forward` returns a tuple whose first element is the context tensor with the same shape as the
    input; the attention probabilities are appended only when `output_attentions` is true.
    """

    def __init__(self, ) -> None:
        super().__init__()

        hidden_dim, head_count = 768, 12
        self.num_attention_heads = head_count
        self.attention_head_size = int(hidden_dim / head_count)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # The three projections are created in query, key, value order.
        self.query = nn.Linear(hidden_dim, self.all_head_size, bias=True)
        self.key = nn.Linear(hidden_dim, self.all_head_size, bias=True)
        self.value = nn.Linear(hidden_dim, self.all_head_size, bias=True)

        self.dropout = nn.Dropout(0., inplace=False)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last axis into (heads, head_size) and move the head axis in front of the sequence axis."""
        leading_dims = x.size()[:-1]
        per_head = x.view(*leading_dims, self.num_attention_heads, self.attention_head_size)
        return per_head.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """
        Args:
            hidden_states: Input to attend over; the last axis is projected by 768-input linear layers.
            head_mask: Optional tensor multiplied into the attention probabilities.
            output_attentions: Whether to also return the attention probabilities.

        Returns:
            `(context,)` or `(context, attention_probs)`.
        """
        queries = self.transpose_for_scores(self.query(hidden_states))
        keys = self.transpose_for_scores(self.key(hidden_states))
        values = self.transpose_for_scores(self.value(hidden_states))

        # Raw query/key similarity, scaled by sqrt(head_size).
        scores = torch.matmul(queries, keys.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)

        # Softmax over the key axis turns scores into probabilities.
        probs = scores.softmax(dim=-1)

        # Dropout here removes whole attended-to tokens, as in the original Transformer paper.
        probs = self.dropout(probs)

        # Optional head masking.
        if head_mask is not None:
            probs = probs * head_mask

        # Weighted sum of values, then merge (heads, head_size) back into a single hidden axis.
        context = torch.matmul(probs, values).permute(0, 2, 1, 3).contiguous()
        context = context.view(*context.size()[:-2], self.all_head_size)

        if output_attentions:
            return (context, probs)
        return (context,)

In [4]:
# Smoke test: self-attention must return a 1-tuple whose tensor has the same shape as its input.
test_image = torch.randn(32, 196, 768)
vitattention = ViTSelfAttention()
with torch.no_grad():
    output = vitattention(test_image)
    # Assert the invariant so a regression fails the cell instead of only changing the printout.
    assert len(output) == 1 and output[0].shape == test_image.shape, output[0].shape
    print(output[0].shape)

torch.Size([32, 196, 768])


In [5]:
class ViTSelfOutput(nn.Module):
    """
    Output projection that follows self-attention: a 768 -> 768 linear layer and then dropout (p=0.1).

    No residual connection is added here. As with the original note on this class, the residual lives in
    ViTLayer (unlike other models) because layernorm is applied before each block.
    """

    def __init__(self,) -> None:
        super().__init__()
        width = 768
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(0.1, inplace=False)

    def forward(self, hidden_states: torch.Tensor, input_tensor) -> torch.Tensor:
        """Project `hidden_states` and apply dropout; `input_tensor` is accepted but not used."""
        return self.dropout(self.dense(hidden_states))

In [6]:
# Smoke test: the output projection must keep the (batch, tokens, 768) shape of its input.
test_image2 = torch.randn(32, 196, 768)
vitoutput = ViTSelfOutput()
with torch.no_grad():
    output = vitoutput(test_image2, test_image2)
    # Assert the invariant so a regression fails the cell instead of only changing the printout.
    assert output.shape == test_image2.shape, output.shape
    print(output.shape)

torch.Size([32, 196, 768])


In [7]:
class ViTAttention(nn.Module):
    """Self-attention followed by its output projection, bundled as one module."""

    def __init__(self,) -> None:
        super().__init__()
        self.attention = ViTSelfAttention()
        self.output = ViTSelfOutput()

    def forward(self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False,):
        """
        Run `self.attention`, pass its first result through `self.output`, and return a tuple made of the
        projected tensor followed by whatever extra items the attention sub-module returned.
        """
        attention_results = self.attention(hidden_states, head_mask, output_attentions)
        context, *extras = attention_results
        projected = self.output(context, hidden_states)

        # Extras hold the attention probabilities when they were requested; otherwise the list is empty.
        return (projected, *extras)

In [8]:
# Smoke test: attention plus output projection must return a 1-tuple that keeps the input shape.
test_image_3 = torch.randn(32, 196, 768)
vitattention_whole = ViTAttention()
with torch.no_grad():
    output = vitattention_whole(test_image_3)
    # Assert the invariant so a regression fails the cell instead of only changing the printout.
    assert len(output) == 1 and output[0].shape == test_image_3.shape, output[0].shape
    print(output[0].shape)

torch.Size([32, 196, 768])


In [9]:
class ViTEmbeddings(nn.Module):
    """
    Construct the patch embeddings and add learned position embeddings to them.

    No CLS token and no mask token are created: only the ViT encoder is used to encode the image, so
    `position_embeddings` holds exactly one entry per patch.
    """

    def __init__(self) -> None:
        super().__init__()

        self.patch_embeddings = ViTPatchEmbeddings()
        num_patches = self.patch_embeddings.num_patches
        # One learned 768-dimensional position vector per patch (no extra slot for a CLS token).
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, 768))
        self.dropout = nn.Dropout(0., inplace=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Resize the learned position encodings so the model can be used on images of another resolution.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Args:
            embeddings: Patch embeddings of shape `(batch_size, num_patches, dim)`.
            height: Height in pixels of the images the embeddings were computed from.
            width: Width in pixels of the images the embeddings were computed from.

        Returns:
            `self.position_embeddings` unchanged when the patch count already matches and the image is
            square; otherwise a tensor of shape `(1, (height // patch_h) * (width // patch_w), dim)`.
        """
        num_patches = embeddings.shape[1]
        num_positions = self.position_embeddings.shape[1]
        # Nothing to do when the sequence length already matches and the image is square.
        if num_patches == num_positions and height == width:
            return self.position_embeddings

        # There is no CLS slot to strip, so every stored position belongs to a patch. This assignment was
        # previously missing, which raised UnboundLocalError as soon as interpolation was actually needed.
        patch_pos_embed = self.position_embeddings
        dim = embeddings.shape[-1]

        # Target grid, derived from the patch size the patch-embedding layer was built with.
        patch_height, patch_width = self.patch_embeddings.patch_size
        new_height = height // patch_height
        new_width = width // patch_width

        # The stored positions are laid out as a square grid of side sqrt(num_positions).
        grid_side = int(math.sqrt(num_positions))
        patch_pos_embed = patch_pos_embed.reshape(1, grid_side, grid_side, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        # Request the exact output size rather than a float scale factor nudged by +0.1, so the result
        # cannot come out one row or column off because of floating point rounding.
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )
        # Back to (1, new_height * new_width, dim); reshape copes with the permuted memory layout.
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
        return patch_pos_embed

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        """
        Embed an image batch and add position information.

        Args:
            pixel_values: Tensor of shape `(batch_size, num_channels, height, width)`.
            bool_masked_pos: Accepted but not used (this module has no mask token).
            interpolate_pos_encoding: When `True`, the position embeddings are resized to the input
                resolution; the flag is also forwarded to the patch-embedding layer.

        Returns:
            Tensor of shape `(batch_size, num_patches, 768)`.
        """
        _, _, height, width = pixel_values.shape

        # The flag is forwarded because the patch-embedding layer has its own resolution check.
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        # No CLS token is prepended, so the sequence consists of patch tokens only.
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings

In [10]:
# Smoke test: patch embedding alone turns a 224x224 RGB image into 196 tokens of width 768.
test_image = torch.randn(1, 3, 224, 224)
embedding = ViTPatchEmbeddings()
with torch.no_grad():
    output = embedding(test_image)
    # Assert the invariant so a regression fails the cell instead of only changing the printout.
    assert output.shape == (1, 196, 768), output.shape
    print(output.shape)

torch.Size([1, 196, 768])


In [11]:
# Smoke test: patch plus position embeddings keep the (1, 196, 768) shape, since no CLS token is added.
test_image2 = torch.randn(1, 3, 224, 224)
embedding = ViTEmbeddings()
with torch.no_grad():
    output = embedding(test_image2)
    # Assert the invariant so a regression fails the cell instead of only changing the printout.
    assert output.shape == (1, 196, 768), output.shape
    print(output.shape)

torch.Size([1, 196, 768])
