In [1]:
import torch
import torch.nn as nn
from einops import rearrange

from timm.models.layers import trunc_normal_, DropPath

import math

Import the customized embeddings, since we removed the CLS and other useless part from our code

In [2]:
# This would be done when later the Vit_Linear_Encoder is transfered into a .py file

In [3]:
from torchvision import transforms

import os
import numpy as np 

from typing import Optional, List, Union

import matplotlib.pyplot as plt 
import matplotlib

from tqdm.notebook import tqdm

from PIL import Image

from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel

In [4]:
# Initializing a ViT & BERT style configuration
config_encoder = ViTConfig()
config_decoder = BertConfig()

config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

# Initializing a ViTBert model from a ViT & bert-base-uncased style configurations
model = VisionEncoderDecoderModel(config=config)

# Accessing the model configuration
config_encoder = model.config.encoder
# print(model.encoder.encoder.layer[0])
# config_decoder  = model.config.decoder
# set decoder config to causal lm
# config_decoder.is_decoder = True
# config_decoder.add_cross_attention = True

ViTLayer(
  (attention): ViTAttention(
    (attention): ViTSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (output): ViTSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
  )
  (intermediate): ViTIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): ViTOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)


In [5]:
class ViTPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self):
        super().__init__()
        image_size, patch_size = 224, 16
        num_channels, hidden_size = 3, 768

        image_size = (image_size, image_size)
        patch_size = (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        if not interpolate_pos_encoding:
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )
        # the shape would be [batch_size, hidden_size, height//patch_size, width//patch_size] before the flatten operation, after flatten we have [batch_size, hidden_size, num_patches]
        # then after the transpose operation, we have [batch_size, num_patches, hidden_size]
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings

In [6]:
class LinearViTEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
    """

    def __init__(self) -> None:
        super().__init__()

        # self.cls_token = nn.Parameter(torch.randn(1, 1, 768))
        # self.mask_token = nn.Parameter(torch.zeros(1, 1, 768)) if use_mask_token else None
        self.patch_embeddings = ViTPatchEmbeddings()
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, 768))
        self.dropout = nn.Dropout(0., inplace=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        
        # remove the cls token, since we are only using the ViT Encoder to encode the image
        # If no interpolation is needed, return the position embeddings as they are
        num_patches = embeddings.shape[1]
        num_positions = self.position_embeddings.shape[1]
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        # the following two lines of code are commented out because we are not using the class token
        # class_pos_embed = self.position_embeddings[:, 0]
        # patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        h0 = height // 16
        w0 = width // 16
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        # the int function in the following function is to make sure h and w are integers, so the weird number problem after we add some special tokens will be solved
        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
            mode="bicubic",
            align_corners=False,
        )
        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape

        # the parameters should be kept since they are calling the forward function of the ViTPatchEmbeddings class
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        
        # comment out the CLS token since we are not using it
        # add the [CLS] token to the embedded patch tokens
        # cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        # embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # change the embedding length if it is longer than the default length, use the interpolation function defined earlier
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings

In [7]:
# Debug

test_embedding = torch.randn(1, 3, 224, 224)
linearembedding = LinearViTEmbeddings()
with torch.no_grad():
    output = linearembedding(test_embedding)
    print(output.shape)

torch.Size([1, 196, 768])


In [8]:
"""
from transformers import AutoImageProcessor, ViTModel
import torch
from datasets import load_dataset

image = torch.randn(1, 3, 224, 224)
# image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

# print(inputs.size, "after the transformation")

with torch.no_grad():
    outputs = model(image)

print(outputs.last_hidden_state.shape)
"""

'\nfrom transformers import AutoImageProcessor, ViTModel\nimport torch\nfrom datasets import load_dataset\n\nimage = torch.randn(1, 3, 224, 224)\n# image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")\nmodel = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")\n\n# print(inputs.size, "after the transformation")\n\nwith torch.no_grad():\n    outputs = model(image)\n\nprint(outputs.last_hidden_state.shape)\n'

In [9]:
###print(model.embeddings)

"""
test_conv = nn.Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
test_result = test_conv(torch.randn(1, 3, 224, 224))
print(test_result.shape)
"""

'\ntest_conv = nn.Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\ntest_result = test_conv(torch.randn(1, 3, 224, 224))\nprint(test_result.shape)\n'

In [10]:
"""
# Create a toy dataset
batch_size = 10
num_patches = 16 
dim = 128  
H, W = 224, 224
channels = 3

# Create a mock input tensor
input_tensor = torch.randn(32, 3, H, W)

# Instantiate the model
model = VisionEncoderDecoderModel(config=config)

# Run a forward pass
with torch.no_grad():
    output = model(input_tensor)

# Check the output shape
print("Input shape:", input_tensor.shape)
print("Output shape:", output.shape)

# Verify the shape correctness (This will depend on your specific requirements)
expected_output_shape = (batch_size, num_patches, dim)  # Example expected shape
assert output.shape == expected_output_shape, f"Output shape {output.shape} does not match expected shape {expected_output_shape}"
"""

'\n# Create a toy dataset\nbatch_size = 10\nnum_patches = 16 \ndim = 128  \nH, W = 224, 224\nchannels = 3\n\n# Create a mock input tensor\ninput_tensor = torch.randn(32, 3, H, W)\n\n# Instantiate the model\nmodel = VisionEncoderDecoderModel(config=config)\n\n# Run a forward pass\nwith torch.no_grad():\n    output = model(input_tensor)\n\n# Check the output shape\nprint("Input shape:", input_tensor.shape)\nprint("Output shape:", output.shape)\n\n# Verify the shape correctness (This will depend on your specific requirements)\nexpected_output_shape = (batch_size, num_patches, dim)  # Example expected shape\nassert output.shape == expected_output_shape, f"Output shape {output.shape} does not match expected shape {expected_output_shape}"\n'

Only Problem right now: Drop path use or not?

In [11]:
class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H = 224, W = 224):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x

ViTAttention(
  (attention): ViTSelfAttention(
    (query): Linear(in_features=768, out_features=768, bias=True)
    (key): Linear(in_features=768, out_features=768, bias=True)
    (value): Linear(in_features=768, out_features=768, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (output): ViTSelfOutput(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
)

In [12]:
class FocusedLinearAttention(nn.Module):
    def __init__(self, dim=768, num_patches=196, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
                 linear=False, focusing_factor=3, kernel_size=5):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.query = nn.Linear(dim, dim, bias=qkv_bias)
        self.key = nn.Linear(dim, dim, bias=qkv_bias)
        self.value = nn.Linear(dim, dim, bias=qkv_bias)
        self.dropout = nn.Dropout(attn_drop, inplace=False)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.linear = linear

        self.pool = nn.AdaptiveAvgPool2d(7)
        self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
        self.norm = nn.LayerNorm(dim)
        self.act = nn.GELU()

        self.focusing_factor = focusing_factor
        self.dwc = nn.Conv2d(in_channels=head_dim, out_channels=head_dim, kernel_size=kernel_size,
                             groups=head_dim, padding=kernel_size // 2)
        self.scale = nn.Parameter(torch.zeros(size=(1, 1, dim)))
        self.positional_encoding = nn.Parameter(torch.zeros(size=(1, num_patches, dim)))
        print('Linear Attention f{} kernel{}'.
              format(focusing_factor, kernel_size))

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H = 14, W = 14, head_mask=None, output_attentions=False):
        B, N, C = x.shape
        print(x.shape)
        print('B': B, 'N': N, 'C': C)
        q = self.query(x)
        print(q.shape, type(q))
        k = self.key(x)
        v = self.value(x)
        x = x.permute(0, 2, 1).reshape(B, C, H, W)

        k = k + self.positional_encoding
        focusing_factor = self.focusing_factor
        kernel_function = nn.ReLU()
        scale = nn.Softplus()(self.scale)
        q = kernel_function(q) + 1e-6
        k = kernel_function(k) + 1e-6
        q = q / scale
        k = k / scale
        q_norm = q.norm(dim=-1, keepdim=True)
        k_norm = k.norm(dim=-1, keepdim=True)
        q = q ** focusing_factor
        k = k ** focusing_factor
        q = (q / q.norm(dim=-1, keepdim=True)) * q_norm
        k = (k / k.norm(dim=-1, keepdim=True)) * k_norm
        q, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v])
        i, j, c, d = q.shape[-2], k.shape[-2], k.shape[-1], v.shape[-1]

        z = 1 / (torch.einsum("b i c, b c -> b i", q, k.sum(dim=1)) + 1e-6)
        if i * j * (c + d) > c * d * (i + j):
            kv = torch.einsum("b j c, b j d -> b c d", k, v)
            x = torch.einsum("b i c, b c d, b i -> b i d", q, kv, z)
        else:
            qk = torch.einsum("b i c, b j c -> b i j", q, k)
            x = torch.einsum("b i j, b j d, b i -> b i d", qk, v, z)

        num = int(v.shape[1] ** 0.5)
        feature_map = rearrange(v, "b (w h) c -> b c w h", w=num, h=num)
        feature_map = rearrange(self.dwc(feature_map), "b c w h -> b (w h) c")
        x = x + feature_map
        x = rearrange(x, "(b h) n c -> b n (h c)", h=self.num_heads)

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

In [13]:
attention = FocusedLinearAttention()

print(attention)

Linear Attention f3 kernel5
FocusedLinearAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (proj): Linear(in_features=768, out_features=768, bias=True)
  (proj_drop): Dropout(p=0.0, inplace=False)
  (pool): AdaptiveAvgPool2d(output_size=7)
  (sr): Conv2d(768, 768, kernel_size=(1, 1), stride=(1, 1))
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (act): GELU(approximate='none')
  (dwc): Conv2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=96)
)


In [14]:
# Debug

test_image = torch.randn(32, 196, 768)
attention = FocusedLinearAttention()
with torch.no_grad():
    output = attention(test_image)
    print(output.shape)

Linear Attention f3 kernel5
torch.Size([32, 196, 768])
torch.Size([32, 196, 768])


In [15]:
# change the attention mechanism in the encoder
for layer in model.encoder.encoder.layer:
    # Replace the attention mechanism
    layer.attention = FocusedLinearAttention()

Linear Attention f3 kernel5
Linear Attention f3 kernel5
Linear Attention f3 kernel5
Linear Attention f3 kernel5
Linear Attention f3 kernel5
Linear Attention f3 kernel5
Linear Attention f3 kernel5
Linear Attention f3 kernel5
Linear Attention f3 kernel5
Linear Attention f3 kernel5
Linear Attention f3 kernel5
Linear Attention f3 kernel5


Successfully changed

In [16]:
# changed the embedding in the encoder into our own LinearViTEmbeddings
model.encoder.embeddings = LinearViTEmbeddings()

In [17]:
# Debug if the embedding is changed
# Remember to remove this chunk in the final github version

print(model.encoder.embeddings)

LinearViTEmbeddings(
  (patch_embeddings): ViTPatchEmbeddings(
    (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (dropout): Dropout(p=0.0, inplace=False)
)


In [18]:
# Debug if the attention mechanism has been replaced
# Remeber to remove this chunk in the final github version

# print(model.encoder.encoder.layer[0].attention)

Final Test, load the pretrained weights into the model, if this returns no error then this whole stuff would work

In [19]:
from transformers import ViTModel
pretrained_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")


# Prepare to load weights
pretrained_dict = pretrained_model.state_dict()
custom_dict = model.state_dict()

# Filter out unnecessary keys and update custom model dict
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in custom_dict and 'attention' in k}
custom_dict.update(pretrained_dict)

# Load the updated state dict into custom model
model.load_state_dict(custom_dict, strict=False)

# Save the model
# model.save_pretrained("model_weights/my-vit-bert")

print(model.encoder.encoder.layer[0].attention)

FocusedLinearAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (proj): Linear(in_features=768, out_features=768, bias=True)
  (proj_drop): Dropout(p=0.0, inplace=False)
  (pool): AdaptiveAvgPool2d(output_size=7)
  (sr): Conv2d(768, 768, kernel_size=(1, 1), stride=(1, 1))
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (act): GELU(approximate='none')
  (dwc): Conv2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=96)
)


In [20]:
from datasets import load_dataset

In [21]:
# Debug
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
test_image = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = pretrained_model(test_image)
    print(output.last_hidden_state.shape)

torch.Size([1, 197, 768])


In [25]:
with torch.no_grad():
    output = model(test_image)
    print(output.shape)

torch.Size([1, 196, 768])


TypeError: reshape(): argument 'shape' failed to unpack the object at pos 3 with error "type must be tuple of ints,but got NoneType"

In [None]:
# Extract weights from the pre-trained model
pretrained_weights = {}
for name, param in model.named_parameters():
    if 'attention.self.query' in name or 'attention.self.key' in name or 'attention.self.value' in name or 'attention.self.dropout' in name:
        pretrained_weights[name] = param

In [None]:
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.linear = linear
        if self.linear:
            self.relu = nn.ReLU(inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        if self.linear:
            x = self.relu(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x



class Block(nn.Module):

    def __init__(self, dim, num_patches, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,linear=False,
                 focusing_factor=3, kernel_size=5):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = FocusedLinearAttention(
                dim, num_patches,
                num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                attn_drop=attn_drop, proj_drop=drop, linear=linear,
                focusing_factor=focusing_factor, kernel_size=kernel_size)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x

In [None]:
model.encoder.encoder.layer[0].attention.attention

In [None]:
model.encoder.encoder.layer[0].attention.attention  =  FocusedLinearAttention(dim, num_patches, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
                 linear=False, focusing_factor=3, kernel_size=5)

In [None]:
# Saving the model, including its configuration
model.save_pretrained('my-model')

# loading model and config from pretrained folder
encoder_decoder_config = VisionEncoderDecoderConfig.from_pretrained('my-model')
model = VisionEncoderDecoderModel.from_pretrained('my-model', config=encoder_decoder_config)

In [None]:
# Create a toy dataset
batch_size = 10
num_patches = 16 
dim = 128  
H, W = 4, 4 

# Create a mock input tensor
input_tensor = torch.randn(batch_size, num_patches, dim)

# Instantiate the model
model = Block(dim=dim, num_patches=num_patches, num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, linear=False, focusing_factor=3, kernel_size=5)

# Run a forward pass
output = model(input_tensor, H, W)

# Check the output shape
print("Input shape:", input_tensor.shape)
print("Output shape:", output.shape)

# Verify the shape correctness (This will depend on your specific requirements)
expected_output_shape = (batch_size, num_patches, dim)  # Example expected shape
assert output.shape == expected_output_shape, f"Output shape {output.shape} does not match expected shape {expected_output_shape}"