In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
import collections.abc

def to_2tuple(x):
    """Return ``x`` unchanged if it is iterable, otherwise the pair ``(x, x)``."""
    return x if isinstance(x, collections.abc.Iterable) else (x, x)

# Normalise scalar sizes into (height, width) pairs.
img_size = to_2tuple(224)
patch_size = to_2tuple(16)

img_size, patch_size

((224, 224), (16, 16))

In [None]:
x = torch.rand(32, 3, 224, 224)  # a batch of 32 random 3-channel 224x224 "images"
x.size()

torch.Size([32, 3, 224, 224])

In [None]:
# Patch projection: kernel 16 with stride 16 maps each non-overlapping 16x16 patch to 768 channels.
proj = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)
y = proj(x)
y.size()

torch.Size([32, 768, 14, 14])

In [None]:
y.flatten(start_dim=2).size()  # merge the 14x14 patch grid into 196 positions

torch.Size([32, 768, 196])

In [None]:
y.flatten(start_dim=2).permute(0, 2, 1).size()  # -> (batch, tokens, channels)

torch.Size([32, 196, 768])

In [None]:
class PatchEmbeddings(nn.Module):
    """
    Split an image into non-overlapping patches and linearly embed each one.

    A ``Conv2d`` whose kernel size and stride both equal the patch size acts as the
    per-patch linear projection; its output grid is flattened into a token sequence
    of shape ``(batch, num_patches, embed_dim)``.
    """

    def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
        super().__init__()
        self.image_size = to_2tuple(image_size)
        self.patch_size = to_2tuple(patch_size)

        grid_height = self.image_size[0] // self.patch_size[0]
        grid_width = self.image_size[1] // self.patch_size[1]
        self.num_patches = grid_height * grid_width

        self.projection = nn.Conv2d(
            num_channels, embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

    def forward(self, pixel_values):
        """Embed ``pixel_values`` of shape ``(batch, channels, height, width)``.

        Raises:
            ValueError: if the spatial size differs from the configured ``image_size``.
        """
        _batch, _channels, height, width = pixel_values.shape
        # FIXME look at relaxing size constraints
        if not (height == self.image_size[0] and width == self.image_size[1]):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        patch_grid = self.projection(pixel_values)  # (batch, embed_dim, grid_h, grid_w)
        return patch_grid.flatten(start_dim=2).transpose(1, 2)


In [None]:
cls_token = nn.Parameter(torch.zeros(1, 1, 768))  # a single learnable [CLS] vector
cls_token.size()

torch.Size([1, 1, 768])

In [None]:
cls_token.expand(32, 1, 768).size()  # broadcast (no copy) to the batch size

torch.Size([32, 1, 768])

In [None]:
position_embedding = nn.Parameter(torch.zeros(1, 14 * 14 + 1, 768))  # 196 patches + 1 [CLS]
position_embedding.size()

torch.Size([1, 197, 768])

In [None]:
class ViTEmbeddings(nn.Module):
    """
    Build the ViT input sequence: a learnable [CLS] token prepended to the patch
    embeddings, plus learnable position embeddings, followed by dropout.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size

        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden))
        self.patch_embeddings = PatchEmbeddings(
            image_size=config.image_size,
            patch_size=config.patch_size,
            num_channels=config.num_channels,
            embed_dim=hidden,
        )
        # One position slot per patch plus one for the [CLS] token.
        seq_len = self.patch_embeddings.num_patches + 1
        self.position_embeddings = nn.Parameter(torch.zeros(1, seq_len, hidden))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, pixel_values):
        """Return ``(batch, num_patches + 1, hidden_size)`` token embeddings."""
        patch_tokens = self.patch_embeddings(pixel_values)
        cls_tokens = self.cls_token.expand(pixel_values.shape[0], -1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)
        return self.dropout(tokens + self.position_embeddings)

In [None]:
class ViTConfig():
    """
    Hyperparameters for the ViT modules in this notebook (a plain-Python stand-in
    for Hugging Face's ``ViTConfig``).

    Keyword arguments not listed in the signature (e.g. ``gradient_checkpointing=True``,
    which ``ViTEncoder`` reads via ``getattr``) are stored as attributes instead of
    being silently discarded. ``output_attentions``, ``output_hidden_states`` and
    ``return_dict`` set the defaults that ``ViTModel.forward`` falls back to.
    """

    def __init__(
        self,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        is_encoder_decoder=False,
        image_size=224,
        patch_size=16,
        num_channels=3,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.is_encoder_decoder = is_encoder_decoder

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels

        # Output switches read by ViTModel.forward when its own arguments are None.
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
        self.use_return_dict = kwargs.pop("return_dict", True)

        # Keep any remaining options (e.g. gradient_checkpointing) instead of dropping them.
        for name, value in kwargs.items():
            setattr(self, name, value)


configuration = ViTConfig()
# The full Hugging Face configuration class (all options documented) is here: https://github.com/huggingface/transformers/blob/master/src/transformers/models/vit/configuration_vit.py

In [None]:
configuration.__dict__  # every hyperparameter stored on the config instance

{'attention_probs_dropout_prob': 0.0,
 'hidden_act': 'gelu',
 'hidden_dropout_prob': 0.0,
 'hidden_size': 768,
 'image_size': 224,
 'initializer_range': 0.02,
 'intermediate_size': 3072,
 'layer_norm_eps': 1e-12,
 'num_attention_heads': 12,
 'num_channels': 3,
 'num_hidden_layers': 12,
 'patch_size': 16}

In [None]:
x = torch.rand(32, 3, 224, 224)
vit_emb = ViTEmbeddings(config=configuration)
vit_emb

ViTEmbeddings(
  (patch_embeddings): PatchEmbeddings(
    (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (dropout): Dropout(p=0.0, inplace=False)
)

In [None]:
out = vit_emb(x)
out.size()

torch.Size([32, 197, 768])

In [None]:
# One projection to 12 heads x 64 dims = 768 features; keep only the projected tensor.
mat = nn.Linear(768, 12 * 64)(out)
mat.size()

torch.Size([32, 197, 768])

In [None]:
mat.shape[:-1]  # everything except the feature dimension

torch.Size([32, 197])

In [None]:
new_shape = mat.shape[:-1] + (12, 64)  # split 768 features into 12 heads of 64
new_shape

torch.Size([32, 197, 12, 64])

In [None]:
print(out.size())
out = out.view(new_shape)
print(out.size())

torch.Size([32, 197, 768])
torch.Size([32, 197, 12, 64])


In [None]:
out = out.transpose(1, 2)  # (batch, tokens, heads, dim) -> (batch, heads, tokens, dim)
print(out.size())

torch.Size([32, 12, 197, 64])


In [None]:
out2 = out  # alias of the same tensor (no copy); used as the "key" stand-in in the cells below

In [None]:
# Fails on purpose: matmul needs the last dim of `out` (64) to equal the second-to-last
# dim of `out2` (197). Catch the error so Restart & Run All does not stop here.
try:
    torch.matmul(out, out2)
except RuntimeError as err:
    print(f"RuntimeError: {err}")

RuntimeError: ignored

In [None]:
out2.transpose(-2, -1).size()  # swap the last two dims so matmul shapes line up

torch.Size([32, 12, 64, 197])

In [None]:
attention_scores = out @ out2.transpose(-2, -1)  # (..., 197, 64) @ (..., 64, 197)
attention_scores.size()

torch.Size([32, 12, 197, 197])

In [None]:
context_layer = attention_scores.softmax(dim=-1) @ out  # weight the values by attention
context_layer.size()

torch.Size([32, 12, 197, 64])

In [None]:
context_layer = context_layer.transpose(1, 2)  # heads back next to the head dim
context_layer.size()

torch.Size([32, 197, 12, 64])

In [None]:
context_layer.shape[:-2]

torch.Size([32, 197])

In [None]:
new_context_layer_shape = context_layer.shape[:-2] + (12 * 64,)  # merge heads back to 768
new_context_layer_shape

torch.Size([32, 197, 768])

In [None]:
print(context_layer.shape)
# permute() only changes strides, so context_layer is no longer contiguous and view()
# refuses to reinterpret it. Catch the error so the notebook still runs top to bottom.
try:
    context_layer = context_layer.view(*new_context_layer_shape)
except RuntimeError as err:
    print(f"RuntimeError: {err}")
print(context_layer.shape)

torch.Size([32, 197, 12, 64])


RuntimeError: ignored

In [None]:
context_layer = torch.matmul(nn.Softmax(dim=-1)(attention_scores), out)
print(context_layer.shape)
# contiguous() lays the permuted data out row-major again, which view() requires.
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
print(context_layer.shape)
print(context_layer.size()[:-2])
new_context_layer_shape = context_layer.size()[:-2] + (12*64,)
print(new_context_layer_shape)
print(context_layer.shape)
# view() returns a new tensor; assign it, otherwise the reshape is silently discarded.
context_layer = context_layer.view(*new_context_layer_shape)
print(context_layer.shape)


torch.Size([32, 12, 197, 64])
torch.Size([32, 197, 12, 64])
torch.Size([32, 197])
torch.Size([32, 197, 768])
torch.Size([32, 197, 12, 64])
torch.Size([32, 197, 12, 64])


In [None]:
x = torch.randn(3, 2)
y = x.t()  # 2-D transpose: a view that shares x's storage
print(x, y, sep="\n")

tensor([[-0.1447,  0.3310],
        [ 0.9052, -0.9171],
        [-1.2232, -0.5425]])
tensor([[-0.1447,  0.9052, -1.2232],
        [ 0.3310, -0.9171, -0.5425]])


In [None]:
x[0][1] = 42  # in-place write; y views the same storage, so y[1, 0] changes too
print(x, y, sep="\n")
print(x.is_contiguous(), y.is_contiguous(), sep="\n")

tensor([[-0.1447, 42.0000],
        [ 0.9052, -0.9171],
        [-1.2232, -0.5425]])
tensor([[-0.1447,  0.9052, -1.2232],
        [42.0000, -0.9171, -0.5425]])
True
False


This is where the concept of contiguity comes in. `transpose` does not copy any data: `y` is a view that shares `x`'s storage, which is why writing `x[0, 1] = 42` also changed `y[1, 0]`. In the example above, `x` is contiguous but `y` is not, because `y`'s memory layout differs from that of a freshly created tensor of the same shape: reading `y` row by row does not visit its elements in storage order. Note that the word "contiguous" is a bit misleading. The tensor's contents are not scattered across disconnected blocks of memory. The bytes are still allocated in one block; only the order of the elements differs!

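When you call `contiguous()` on a non-contiguous tensor, it returns a copy whose elements are laid out in memory as if the tensor had been created from scratch with the same data. On a tensor that is already contiguous, it simply returns the tensor itself. This is why the attention code calls `.contiguous()` after `permute()`: `view()` only works when the memory layout is compatible with the requested shape. That incompatibility is why the earlier `view()` call raised a `RuntimeError`.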

In [None]:
import math
class ViTSelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention.

    ``hidden_states`` of shape ``(batch, seq_len, hidden_size)`` are projected to
    queries, keys and values, split into ``num_attention_heads`` heads of
    ``attention_head_size`` features each, attended per head, and merged back to
    ``(batch, seq_len, all_head_size)``.
    """

    def __init__(self, config):
        super().__init__()
        # The divisibility check is skipped when the config defines `embedding_size`.
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Reshape ``(batch, seq, all_head_size)`` to ``(batch, heads, seq, head_size)``."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, head_mask=None, output_attentions=False):
        """
        Args:
            hidden_states: input tokens, ``(batch, seq_len, hidden_size)``.
            head_mask: optional tensor multiplied into the attention probabilities
                ``(batch, heads, seq_len, seq_len)``; a 0 entry silences that head.
            output_attentions: if true, also return the attention probabilities.

        Returns:
            ``(context_layer,)`` or ``(context_layer, attention_probs)``.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot product: (b, h, s, d) @ (b, h, d, s) -> (b, h, s, s).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = attention_scores.softmax(dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Merge heads back: (b, h, s, d) -> (b, s, h, d) -> (b, s, h*d).
        # contiguous() is required because view() cannot reinterpret permuted strides.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return (context_layer, attention_probs) if output_attentions else (context_layer,)

In [None]:
vit_atn = ViTSelfAttention(config=configuration)
vit_atn

ViTSelfAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)

In [None]:
x = torch.rand(32, 3, 224, 224)
vit_emb = ViTEmbeddings(config=configuration)
vit_emb

ViTEmbeddings(
  (patch_embeddings): PatchEmbeddings(
    (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (dropout): Dropout(p=0.0, inplace=False)
)

In [None]:
emb = vit_emb(x)
emb.size()

torch.Size([32, 197, 768])

In [None]:
context_layer, attention_probs = vit_atn(emb, output_attentions=True)  # head_mask defaults to None

In [None]:
(context_layer.size(), attention_probs.size())

(torch.Size([32, 197, 768]), torch.Size([32, 12, 197, 197]))

In [None]:
class ViTSelfOutput(nn.Module):
    """
    Output projection of the self-attention block: a Linear layer followed by dropout.

    ``input_tensor`` is accepted for interface compatibility but is not used here;
    the residual connection is added by the caller (``ViTLayer``).
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        return self.dropout(self.dense(hidden_states))

In [None]:
class ViTAttention(nn.Module):
    """
    Self-attention sub-block: ``ViTSelfAttention`` followed by its output projection.

    Supports structured head pruning via ``prune_heads``. The pruning helpers are
    implemented locally, because the Hugging Face utilities this code originally
    called (``find_pruneable_heads_and_indices`` / ``prune_linear_layer``) are not
    imported in this notebook.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = ViTSelfAttention(config)
        self.output = ViTSelfOutput(config)
        self.pruned_heads = set()

    @staticmethod
    def _find_pruneable_heads_and_indices(heads, n_heads, head_size, already_pruned_heads):
        """Return the heads still to prune and the flat feature indices to keep.

        Head ids use the ORIGINAL numbering. Heads already pruned are ignored, and the
        remaining ids are shifted down by the number of earlier heads already removed.
        """
        mask = torch.ones(n_heads, head_size)
        heads = set(heads) - already_pruned_heads
        for head in heads:
            # Position of this head among the heads that still exist.
            shifted = head - sum(1 for h in already_pruned_heads if h < head)
            mask[shifted] = 0
        keep = mask.view(-1).eq(1)
        index = torch.arange(keep.numel())[keep].long()
        return heads, index

    @staticmethod
    def _prune_linear_layer(layer, index, dim=0):
        """Return a new ``nn.Linear`` keeping only ``index`` along ``dim`` of ``layer.weight``.

        ``dim=0`` keeps the selected output features (and matching bias entries);
        ``dim=1`` keeps the selected input features and copies the bias unchanged.
        """
        index = index.to(layer.weight.device)
        weight = layer.weight.index_select(dim, index).detach().clone()
        bias = None
        if layer.bias is not None:
            bias = layer.bias.detach().clone() if dim == 1 else layer.bias[index].detach().clone()

        out_features, in_features = layer.weight.shape
        if dim == 0:
            out_features = len(index)
        else:
            in_features = len(index)
        new_layer = nn.Linear(in_features, out_features, bias=bias is not None)
        new_layer = new_layer.to(device=layer.weight.device, dtype=layer.weight.dtype)
        with torch.no_grad():
            new_layer.weight.copy_(weight)
            if bias is not None:
                new_layer.bias.copy_(bias)
        return new_layer

    def prune_heads(self, heads):
        """Remove the given attention heads (original head ids) from this block in place."""
        if len(heads) == 0:
            return
        heads, index = self._find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )
        if not heads:
            # Every requested head was already pruned; nothing to rebuild.
            return

        # Prune linear layers: Q/K/V lose output features, the output projection loses inputs.
        self.attention.query = self._prune_linear_layer(self.attention.query, index)
        self.attention.key = self._prune_linear_layer(self.attention.key, index)
        self.attention.value = self._prune_linear_layer(self.attention.value, index)
        self.output.dense = self._prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, head_mask=None, output_attentions=False):
        """Return ``(attention_output,)``, plus the attention probabilities if requested."""
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs

In [None]:
class ViTIntermediate(nn.Module):
    """
    First half of the transformer MLP: ``hidden_size -> intermediate_size`` followed
    by the activation named in ``config.hidden_act``.

    ``hidden_act`` may be one of the names in ``_ACTIVATIONS`` or any callable.
    Configs without the attribute fall back to ``"gelu"``, the activation this layer
    previously hard-coded.
    """

    # Supported activation names.
    _ACTIVATIONS = {
        "gelu": nn.functional.gelu,
        "relu": nn.functional.relu,
        "silu": nn.functional.silu,
        "swish": nn.functional.silu,
        "tanh": torch.tanh,
        "sigmoid": torch.sigmoid,
    }

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = self._resolve_activation(getattr(config, "hidden_act", "gelu"))

    @classmethod
    def _resolve_activation(cls, act):
        """Map ``act`` (a name or a callable) to the activation function to apply.

        Raises:
            ValueError: for an unknown activation name.
            TypeError: if ``act`` is neither a string nor callable.
        """
        if isinstance(act, str):
            try:
                return cls._ACTIVATIONS[act]
            except KeyError:
                raise ValueError(
                    f"Unsupported hidden_act {act!r}; expected a callable or one of {sorted(cls._ACTIVATIONS)}."
                ) from None
        if callable(act):
            return act
        raise TypeError(f"hidden_act must be a string or a callable, got {type(act).__name__}.")

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        return self.intermediate_act_fn(hidden_states)


In [None]:
class ViTOutput(nn.Module):
    """
    Second half of the transformer MLP: ``intermediate_size -> hidden_size``, dropout,
    and the residual connection with ``input_tensor``.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor

In [None]:
class ViTLayer(nn.Module):
    """
    One pre-norm transformer block; this corresponds to the Block class in the timm implementation.

    ``x = x + Attention(LayerNorm(x))`` followed by ``x = x + MLP(LayerNorm(x))``.
    """

    def __init__(self, config):
        super().__init__()
        self.seq_len_dim = 1
        self.attention = ViTAttention(config)
        self.intermediate = ViTIntermediate(config)
        self.output = ViTOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, head_mask=None, output_attentions=False):
        """Return ``(layer_output,)``, plus the attention probabilities if requested."""
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViT, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # MLP sub-block: layernorm -> intermediate -> output (+ second residual connection)
        layer_output = self.feed_forward_chunk(hidden_states)

        return (layer_output,) + outputs

    def feed_forward_chunk(self, attention_output):
        """Apply the MLP sub-block with its residual: ``x + MLP(LayerNorm(x))``.

        ``attention_output`` is the hidden state after the first residual connection.
        Previously this called ``self.output`` without its residual argument (a
        TypeError) and skipped ``layernorm_after``.
        """
        intermediate_output = self.intermediate(self.layernorm_after(attention_output))
        return self.output(intermediate_output, attention_output)

In [None]:
# Named result of ViTEncoder.forward(return_dict=True). It is still a 3-tuple, so code that
# unpacks or indexes the result keeps working, while callers such as ViTModel can also read
# `.hidden_states` / `.attentions` by name.
ViTEncoderOutput = collections.namedtuple(
    "ViTEncoderOutput", ["last_hidden_state", "hidden_states", "attentions"]
)


class ViTEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``ViTLayer`` blocks applied in sequence."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """
        Args:
            hidden_states: input token embeddings.
            head_mask: optional per-layer masks, indexed as ``head_mask[i]`` for layer ``i``.
            output_attentions: collect each layer's attention probabilities.
            output_hidden_states: collect the input to every layer plus the final output.
            return_dict: if true, return a ``ViTEncoderOutput`` 3-tuple (``None`` for
                anything not collected); otherwise a plain tuple with the ``None`` entries dropped.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        use_checkpointing = getattr(self.config, "gradient_checkpointing", False) and self.training

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if use_checkpointing:
                # Imported explicitly: `import torch` alone may not load this submodule.
                from torch.utils import checkpoint as torch_checkpoint

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch_checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    layer_head_mask,
                )
            else:
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)

        return ViTEncoderOutput(hidden_states, all_hidden_states, all_self_attentions)


In [None]:
class ViTModel(nn.Module):
    """
    Bare ViT: patch/position embeddings -> transformer encoder -> final LayerNorm,
    with an optional ``ViTPooler`` over the [CLS] token.

    Adapted from Hugging Face ``transformers``, where ``ViTModel`` inherits
    ``PreTrainedModel``. Here it subclasses ``nn.Module`` directly (the original
    ``class ViTModel():`` could not even be constructed), so the pieces that base
    class provided (``init_weights``, ``get_head_mask``, config defaults) are
    implemented locally.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__()
        self.config = config

        self.embeddings = ViTEmbeddings(config)
        self.encoder = ViTEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # ViTPooler is defined in a later cell; it only has to exist when a model is built.
        self.pooler = ViTPooler(config) if add_pooling_layer else None

        self.init_weights()

    def init_weights(self):
        """Initialise every submodule with ``_init_weights``."""
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, initializer_range) Linear/Conv2d weights with zero bias; LayerNorm to identity."""
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def get_head_mask(self, head_mask, num_hidden_layers):
        """Expand a user head mask into one entry per encoder layer.

        ``None`` becomes ``[None] * num_hidden_layers``. A ``[num_heads]`` mask is shared
        by all layers, and a ``[num_hidden_layers, num_heads]`` mask gives one row per layer.
        Each layer's entry is shaped ``[1, num_heads, 1, 1]`` so it broadcasts against
        attention probabilities of shape ``[batch, num_heads, seq, seq]``.
        """
        if head_mask is None:
            return [None] * num_hidden_layers
        if head_mask.dim() == 1:
            head_mask = head_mask[None, None, :, None, None].expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask[:, None, :, None, None]
        return head_mask.to(dtype=next(self.parameters()).dtype)

    def forward(
        self,
        pixel_values=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Returns:
            If ``return_dict`` is true: ``(sequence_output, pooled_output, hidden_states, attentions)``,
            where the last two are ``None`` unless requested and ``pooled_output`` is ``None``
            without a pooler. Otherwise ``(sequence_output, pooled_output)`` followed by
            whichever of hidden states / attentions were collected.

        Examples::

            >>> model = ViTModel(ViTConfig())
            >>> sequence_output, pooled_output, hidden_states, attentions = model(torch.rand(2, 3, 224, 224))
            >>> sequence_output.shape, pooled_output.shape
            (torch.Size([2, 197, 768]), torch.Size([2, 768]))
        """
        # Fall back to config defaults; getattr keeps configs without these flags working.
        if output_attentions is None:
            output_attentions = getattr(self.config, "output_attentions", False)
        if output_hidden_states is None:
            output_hidden_states = getattr(self.config, "output_hidden_states", False)
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # 1.0 in head_mask indicates we keep the head.
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + tuple(encoder_outputs[1:])

        # Index positionally: the encoder returns (last_hidden_state, hidden_states, attentions).
        return sequence_output, pooled_output, encoder_outputs[1], encoder_outputs[2]


In [None]:
class ViTPooler(nn.Module):
    """Pool a sequence by projecting its first ([CLS]) token through Linear + Tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        cls_state = hidden_states[:, 0]  # hidden state of the first token in every sequence
        return self.activation(self.dense(cls_state))

In [None]:
(out.size(), configuration)

(torch.Size([32, 12, 197, 64]), <__main__.ViTConfig at 0x7f31833ef090>)

In [None]:
vit_enc = ViTEncoder(config=configuration)

In [None]:
vit_enc  # display the encoder's module tree

ViTEncoder(
  (layer): ModuleList(
    (0): ViTLayer(
      (attention): ViTAttention(
        (attention): ViTSelfAttention(
          (query): Linear(in_features=768, out_features=768, bias=True)
          (key): Linear(in_features=768, out_features=768, bias=True)
          (value): Linear(in_features=768, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (output): ViTSelfOutput(
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
      (intermediate): ViTIntermediate(
        (dense): Linear(in_features=768, out_features=3072, bias=True)
      )
      (output): ViTOutput(
        (dense): Linear(in_features=3072, out_features=768, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine

In [None]:
input = torch.rand((32, 3, 224, 224))
embeddings = ViTEmbeddings(configuration)
encoder = ViTEncoder(configuration)
# `config` is not defined anywhere in this notebook (NameError on a fresh kernel); use
# `configuration`, and take eps from it so this matches ViTModel's final LayerNorm.
layernorm = nn.LayerNorm(configuration.hidden_size, eps=configuration.layer_norm_eps)

In [None]:
embedding_output = embeddings(input)
embedding_output.size()

torch.Size([32, 197, 768])

In [None]:
encoder_output = encoder(hidden_states=embedding_output)

In [None]:
(type(encoder_output), len(encoder_output))

(tuple, 3)

In [None]:
hidden_states, all_hidden_states, all_self_attentions = encoder_output

(hidden_states.size(), all_hidden_states, all_self_attentions)

(torch.Size([32, 197, 768]), None, None)

In [None]:
sequence_output = encoder_output[0]  # last hidden state of the encoder
sequence_output.size()

torch.Size([32, 197, 768])

In [None]:
sequence_output[:, 0, :].size()  # the [CLS] token of every sample

torch.Size([32, 768])

In [None]:
# Manual version of ViTModel's tail: final LayerNorm, then the ViTPooler on the [CLS] token.
# Uses `configuration` (`config` is undefined in this notebook) and the model's LayerNorm eps.
sequence_output = encoder_output[0]
layernorm = nn.LayerNorm(configuration.hidden_size, eps=configuration.layer_norm_eps)
sequence_output = layernorm(sequence_output)
# ViTPooler
dense = nn.Linear(configuration.hidden_size, configuration.hidden_size)
activation = nn.Tanh()
first_token_tensor = sequence_output[:, 0]
pooled_output = dense(first_token_tensor)
pooled_output = activation(pooled_output)
pooled_output.shape


torch.Size([32, 768])

In [None]:
# Example classification head over the pooled [CLS] representation.
# Uses `configuration`: `config` is undefined in this notebook.
NUM_CLASSES = 100
classifier = nn.Linear(configuration.hidden_size, NUM_CLASSES)
logits = classifier(pooled_output)
logits.shape

torch.Size([32, 100])