In [1]:
import torch
import torch.nn as nn

import torch.nn.functional as F
import math
from collections import defaultdict

from timm.models.layers import trunc_normal_
from timm.models.layers import DropPath
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import _load_weights

import torch
import torch.nn as nn
from einops import rearrange
from pathlib import Path

import torch.nn.functional as F

from timm.models.layers import DropPath


class FeedForward1(nn.Module):
    def __init__(self, dim, hidden_dim, dropout, out_dim=None):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        if out_dim is None:
            out_dim = dim
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(dropout)

    @property
    def unwrapped(self):
        return self

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, heads, dropout):
        super().__init__()
        self.heads = heads
        head_dim = dim // heads
        self.scale = head_dim ** -0.5
        self.attn = None

        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    @property
    def unwrapped(self):
        return self

    def forward(self, x, mask=None):
        B, N, C = x.shape
        print(x.shape)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.heads, C // self.heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x, attn


class Block(nn.Module):
    def __init__(self, dim, heads, mlp_dim, dropout, drop_path):
        super().__init__()

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads, dropout)
        self.mlp = FeedForward1(dim, mlp_dim, dropout)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, mask=None, return_attention=False):
        y, attn = self.attn(self.norm1(x), mask)
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

def init_weights(m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)


def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    posemb_tok, posemb_grid = (
        posemb[:, :num_extra_tokens],
        posemb[0, num_extra_tokens:],
    )
    if grid_old_shape is None:
        gs_old_h = int(math.sqrt(len(posemb_grid)))
        gs_old_w = gs_old_h
    else:
        gs_old_h, gs_old_w = grid_old_shape

    gs_h, gs_w = grid_new_shape
    posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    return posemb

class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, embed_dim, channels):
        super().__init__()

        self.image_size = image_size
        print(image_size)
        if image_size[0] % patch_size != 0 or image_size[1] % patch_size != 0:
            raise ValueError("image dimensions must be divisible by the patch size")
        self.grid_size = image_size[0] // patch_size, image_size[1] // patch_size
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.patch_size = patch_size
        print("Patch embedding embed dim", embed_dim)

        self.proj = nn.Conv2d(
            channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, im):
        #B, T, C, H, W = im.shape
        B, C, H, W = im.shape
        #print("efore rearragne",im.shape)
        #im = rearrange(im,
                #"b t c (h ph) (w pw) -> (b t) c (h ph) (w pw)",
                #ph=16,
                #pw=16,
            #)
        print("Patch Embedding im shape", im.shape)
        x = self.proj(im).flatten(2).transpose(1, 2)
        #x=rearrange(x, "(b t) h w -> b t h w ",b=1,t=4)
        
        #x = self.proj(im)
        #x=rearrange(x, "b c h w -> b (h w) c ")
        print("Patch embedding After proj operation x shape", x.shape)
        return x


class VisionTransformer(nn.Module):
    def __init__(
        self,
        image_size,
        patch_size,
        n_layers,
        d_model,
        d_ff,
        n_heads,
        n_cls,
        dropout=0.1,
        drop_path_rate=0.0,
        distilled=False,
        channels=3
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            image_size,
            patch_size,
            d_model,
            channels,
        )
        self.patch_size = patch_size
        self.n_layers = n_layers
        self.d_model = d_model
        self.d_ff = d_ff
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.n_cls = n_cls

        # cls and pos tokens
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.distilled = distilled
        if self.distilled:
            self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
            self.pos_embed = nn.Parameter(
                torch.randn(1, self.patch_embed.num_patches + 2, d_model)
            )
            self.head_dist = nn.Linear(d_model, n_cls)
        else:
            self.pos_embed = nn.Parameter(
                torch.randn(1, self.patch_embed.num_patches + 1, d_model)
            )

        # transformer blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
        self.blocks = nn.ModuleList(
            [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)]
        )

        # output head
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, n_cls)

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        if self.distilled:
            trunc_normal_(self.dist_token, std=0.02)
        self.pre_logits = nn.Identity()

        self.apply(init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token", "dist_token"}

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=""):
        _load_weights(self, checkpoint_path, prefix)

    def forward(self, im, return_features=False):
        print("Vision transformer im shape", im.shape)
        B,T, H, W = im.shape
        PS = self.patch_size

        x = self.patch_embed(im)
        print("ViT after patch embed operation", x.shape)
        cls_tokens = self.cls_token.expand(B, -1,-1)
        print("ViT cls tokens", cls_tokens.shape)
        if self.distilled:
            dist_tokens = self.dist_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
        else:
            x = torch.cat((cls_tokens, x), dim=1)
            print("ViT after cat operatiom", x.shape)

        pos_embed = self.pos_embed
        print("ViT pos_embed", pos_embed.shape)
        num_extra_tokens = 1 + self.distilled
        if x.shape[1] != pos_embed.shape[1]:
            pos_embed = resize_pos_embed(
                pos_embed,
                self.patch_embed.grid_size,
                (H // PS, W // PS),
                num_extra_tokens,
            )
        x = x + pos_embed
        print("ViT after adding pos embedding", x.shape)
        x = self.dropout(x)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        if return_features:
            return x

        """if self.distilled:
            x, x_dist = x[:, 0], x[:, 1]
            x = self.head(x)
            x_dist = self.head_dist(x_dist)
            x = (x + x_dist) / 2
        else:
            x = x[:, 0]
            x = self.head(x)"""
        print("ViT Last step of encoder", x.shape)
        return x

    def get_attention_map(self, im, layer_id):
        if layer_id >= self.n_layers or layer_id < 0:
            raise ValueError(
                f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}."
            )
        B, _, H, W = im.shape
        PS = self.patch_size

        x = self.patch_embed(im)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        if self.distilled:
            dist_tokens = self.dist_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
        else:
            x = torch.cat((cls_tokens, x), dim=1)

        pos_embed = self.pos_embed
        num_extra_tokens = 1 + self.distilled
        if x.shape[1] != pos_embed.shape[1]:
            pos_embed = resize_pos_embed(
                pos_embed,
                self.patch_embed.grid_size,
                (H // PS, W // PS),
                num_extra_tokens,
            )
        x = x + pos_embed

        for i, blk in enumerate(self.blocks):
            if i < layer_id:
                x = blk(x)
            else:
                return blk(x, return_attention=True)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch.nn as nn
def pair(t):
    """
    Parameters
    ----------
    t: tuple[int] or int
    """
    return t if isinstance(t, tuple) else (t, t)

class BaseClassificationModel(nn.Module):
    """
    Parameters
    -----------
    img_size: int
        Size of the image
    patch_size: int or tuple(int)
        Size of the patch
    in_channels: int
        Number of channels in input image
    pool: str
        Feature pooling type, must be one of {``mean``, ``cls``}
    """

    def __init__(self, img_size, patch_size, in_channels=3, pool="cls"):
        super(BaseClassificationModel, self).__init__()

        img_height, img_width = pair(img_size)
        patch_height, patch_width = pair(patch_size)

        assert (
            img_height % patch_height == 0 and img_width % patch_width == 0
        ), "Image dimensions must be divisible by the patch size."

        num_patches = (img_height // patch_height) * (img_width // patch_width)
        patch_dim = in_channels * patch_height * patch_width

        self.patch_height = patch_height
        self.patch_width = patch_width
        self.num_patches = num_patches
        self.patch_dim = patch_dim

        assert pool in {
            "cls",
            "mean",
        }, "Feature pooling type must be either cls (cls token) or mean (mean pooling)"
        self.pool = pool
        
class ViViTModel2(BaseClassificationModel):
    """
    Model 2 implementation of: `ViViT: A Video Vision Transformer <https://arxiv.org/abs/2103.15691>`_
    Parameters
    -----------
    img_size:int
        Size of single frame/ image in video
    in_channels:int
        Number of channels
    patch_size: int
        Patch size
    embedding_dim: int
        Embedding dimension of a patch
    num_frames:int
        Number of seconds in each Video
    depth:int
        Number of encoder layers
    num_heads:int
        Number of attention heads
    head_dim:int
        Dimension of head
    n_classes:int
        Number of classes
    mlp_dim: int
        Dimension of hidden layer
    pool: str
        Pooling operation,must be one of {"cls","mean"},default is "cls"
    p_dropout:float
        Dropout probability
    attn_dropout:float
        Dropout probability
    drop_path_rate:float
        Stochastic drop path rate
    """

    def __init__(
        self,
        img_size,
        in_channels,
        patch_size,
        embedding_dim,
        num_frames,
        depth,
        num_heads,
        head_dim,
        n_classes,
        mlp_dim=None,
        pool="cls",
        p_dropout=0.0,
        attn_dropout=0.0,
        drop_path_rate=0.02,
    ):
        super(ViViTModel2, self).__init__(
            img_size=img_size,
            in_channels=in_channels,
            patch_size=patch_size,
            pool=pool,
        )

        patch_dim = in_channels * patch_size**2
        self.patch_embedding = LinearVideoEmbedding(
            embedding_dim=embedding_dim,
            patch_height=patch_size,
            patch_width=patch_size,
            patch_dim=patch_dim,
        )

        self.pos_embedding = PosEmbedding(
            shape=[num_frames, self.num_patches + 1], dim=embedding_dim, drop=p_dropout
        )

        self.space_token = nn.Parameter(
            torch.randn(1, 1, embedding_dim)
        )  # this is similar to using cls token in vanilla vision transformer
        self.spatial_transformer = VanillaEncoder(
            embedding_dim=embedding_dim,
            depth=depth,
            num_heads=num_heads,
            head_dim=head_dim,
            mlp_dim=mlp_dim,
            p_dropout=p_dropout,
            attn_dropout=attn_dropout,
            drop_path_rate=drop_path_rate,
        )

        self.time_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
        self.temporal_transformer = VanillaEncoder(
            embedding_dim=embedding_dim,
            depth=depth,
            num_heads=num_heads,
            head_dim=head_dim,
            mlp_dim=mlp_dim,
            p_dropout=p_dropout,
            attn_dropout=attn_dropout,
            drop_path_rate=drop_path_rate,
        )



    def forward(self, x):
        print("Beforeeee patch emedding", x.shape)

        x = self.patch_embedding(x)
        print("Afterrrrrr patch emedding", x.shape)

        (
            b,
            t,
            n,
            d,
        ) = x.shape  # shape of x will be number of videos,time,num_frames,embedding dim
        print("After reshapig?",x.shape)
        print("space token",self.space_token.shape)
        
        cls_space_tokens = repeat(self.space_token, "() n d -> b t n d", b=b, t=t)
        print("xls_space_topen",cls_space_tokens.shape)

        x = nn.Parameter(torch.cat((cls_space_tokens, x), dim=2))
        x = self.pos_embedding(x)

        x = rearrange(x, "b t n d -> (b t) n d")
        x = self.spatial_transformer(x)
        print("after spaial shape of x",x.shape)
        x = rearrange(x[:, 0], "(b t) ... -> b t ...", b=b)
        print("after rearrange",x.shape)
        print("temporal token",self.time_token.shape)

        cls_temporal_tokens = repeat(self.time_token, "() n d -> b n d", b=b)
        print("cls_temporal", cls_temporal_tokens.shape)
        x = torch.cat((cls_temporal_tokens, x), dim=1)
        print("after cat operation temporal",x.shape)

        x = self.temporal_transformer(x)
        
        print("after temporal", x.shape)

        x = x.mean(dim=1) if self.pool == "mean" else x[:, 0]
        print("final shape", x.shape)

        #x = self.decoder(x)

        return x

In [3]:
class VanillaEncoder(nn.Module):
    """
    Parameters
    ----------
    embedding_dim: int
        Dimension of the embedding
    depth: int
        Number of self-attention layers
    num_heads: int
        Number of the attention heads
    head_dim: int
        Dimension of each head
    mlp_dim: int
        Dimension of the hidden layer in the feed-forward layer
    p_dropout: float
        Dropout Probability
    attn_dropout: float
        Dropout Probability
    drop_path_rate: float
        Stochastic drop path rate
    """

    def __init__(
        self,
        embedding_dim,
        depth,
        num_heads,
        head_dim,
        mlp_dim,
        p_dropout=0.0,
        attn_dropout=0.0,
        drop_path_rate=0.0,
        drop_path_mode="batch",
    ):
        super().__init__()

        self.encoder = nn.ModuleList([])
        for _ in range(depth):
            self.encoder.append(
                nn.ModuleList(
                    [
                        PreNorm(
                            dim=embedding_dim,
                            fn=VanillaSelfAttention(
                                dim=embedding_dim,
                                num_heads=num_heads,
                                head_dim=head_dim,
                                p_dropout=attn_dropout,
                            ),
                        ),
                        PreNorm(
                            dim=embedding_dim,
                            fn=FeedForward(
                                dim=embedding_dim,
                                hidden_dim=mlp_dim,
                                p_dropout=p_dropout,
                            ),
                        ),
                    ]
                )
            )
        self.drop_path = (
            StochasticDepth(p=drop_path_rate, mode=drop_path_mode)
            if drop_path_rate > 0.0
            else nn.Identity()
        )

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.Tensor
        Returns
        ----------
        torch.Tensor
            Returns output tensor
        """
        for attn, ff in self.encoder:
            x = attn(x) + x
            x = self.drop_path(ff(x)) + x

        return x
class PosEmbedding(nn.Module):
    """
    Generalised Positional Embedding class
    """

    def __init__(self, shape, dim, drop=None, sinusoidal=False, std=0.02):
        super(PosEmbedding, self).__init__()
        print("shape",shape)
        if not sinusoidal:
            if isinstance(shape, int):
                shape = [1, shape, dim]
            else:
                shape = [1] + list(shape) + [dim]
            self.pos_embed = nn.Parameter(torch.zeros(shape))
            print("pos embed shape",self.pos_embed.shape)

        else:
            pe = torch.FloatTensor(
                [
                    [p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
                    for p in range(shape)
                ]
            )
            pe[:, 0::2] = torch.sin(pe[:, 0::2])
            pe[:, 1::2] = torch.cos(pe[:, 1::2])
            self.pos_embed = pe
            self.pos_embed.requires_grad = False
        nn.init.trunc_normal_(self.pos_embed, std=std)
        self.pos_drop = nn.Dropout(drop) if drop is not None else nn.Identity()

    def forward(self, x):
        print("before pos embed", x.shape)
        x = x + self.pos_embed
        return self.pos_drop(x)
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange


class LinearVideoEmbedding(nn.Module):
    """
    Parameters
    -----------
    embedding_dim: int
        Dimension of the resultant embedding
    patch_height: int
        Height of the patch
    patch_width: int
        Width of the patch
    patch_dim: int
        patch_dimension
    """

    def __init__(
        self,
        embedding_dim,
        patch_height,
        patch_width,
        patch_dim,
    ):

        super().__init__()
        self.patch_embedding = nn.Sequential(
            Rearrange(
                "b t c (h ph) (w pw) -> b t (h w) (ph pw c)",
                ph=patch_height,
                pw=patch_width,
            ),
            nn.Linear(patch_dim, embedding_dim),
        )

    def forward(self, x):
        """
        Parameters
        -----------
        x: torch.Tensor
            Input tensor
        Returns
        ----------
        torch.Tensor
            Returns patch embeddings of size `embedding_dim`
        """
        print("eforeeeeeeeeeeeeee",x.shape)
        return self.patch_embedding(x)
    
class PreNorm(nn.Module):
    """
    Parameters
    ----------
    dim: int
        Dimension of the embedding
    fn:nn.Module
        Attention class
    context_dim: int
        Dimension of the context array used in cross attention
    """

    def __init__(self, dim, fn, context_dim=None):
        super().__init__()

        self.norm = nn.LayerNorm(dim)
        self.context_norm = (
            nn.LayerNorm(context_dim) if context_dim is not None else None
        )
        self.fn = fn

    def forward(self, x, **kwargs):
        if "context" in kwargs.keys() and kwargs["context"] is not None:
            normed_context = self.context_norm(kwargs["context"])
            kwargs.update(context=normed_context)
        return self.fn(self.norm(x), **kwargs)

class VanillaSelfAttention(nn.Module):
    """
    Vanilla O(:math:`n^2`) Self attention introduced in `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_
    Parameters
    -----------
    dim: int
        Dimension of the embedding
    num_heads: int
        Number of the attention heads
    head_dim: int
        Dimension of each head
    p_dropout: float
        Dropout Probability
    """

    def __init__(self, dim, num_heads=8, head_dim=64, p_dropout=0.0):
        super().__init__()

        inner_dim = head_dim * num_heads
        project_out = not (num_heads == 1 and head_dim == dim)

        self.num_heads = num_heads
        self.scale = head_dim**-0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = (
            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(p_dropout))
            if project_out
            else nn.Identity()
        )

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor
        Returns
        ----------
        torch.Tensor
            Returns output tensor by applying self-attention on input tensor
        """
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(
            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), qkv
        )

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        #out = rearrange(out, "b h n d -> b n (h d)")

        return self.to_out(out)
    
    
class FeedForward(nn.Module):
    """
    Parameters
    ----------
    dim: int
        Dimension of the input tensor
    hidden_dim: int, optional
        Dimension of hidden layer
    out_dim: int, optional
        Dimension of the output tensor
    p_dropout: float
        Dropout probability, default=0.0
    """

    def __init__(self, dim, hidden_dim=None, out_dim=None, p_dropout=0.0):
        super().__init__()

        out_dim = out_dim if out_dim is not None else dim
        hidden_dim = hidden_dim if hidden_dim is not None else dim

        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p_dropout),
            nn.Linear(hidden_dim, out_dim),
            nn.Dropout(p_dropout),
        )

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor
        Returns
        ----------
        torch.Tensor
            Returns output tensor by performing linear operations and activation on input tensor
        """

        return self.net(x)

In [4]:
from torchvision.ops import StochasticDepth
import torch
from einops import rearrange, repeat
images = torch.randn(1, 4 , 3, 768, 768)

# Example data
model = ViViTModel2(
        img_size=768,
        patch_size=16,
        num_frames = 4,
        in_channels=3,
        n_classes=10,
        embedding_dim=192,
        depth=2,
        num_heads =1,
        head_dim=3,
    )
logits = model(images)


shape [4, 2305]
pos embed shape torch.Size([1, 4, 2305, 192])
Beforeeee patch emedding torch.Size([1, 4, 3, 768, 768])
eforeeeeeeeeeeeeee torch.Size([1, 4, 3, 768, 768])
Afterrrrrr patch emedding torch.Size([1, 4, 2304, 192])
After reshapig? torch.Size([1, 4, 2304, 192])
space token torch.Size([1, 1, 192])
xls_space_topen torch.Size([1, 4, 1, 192])
before pos embed torch.Size([1, 4, 2305, 192])


EinopsError:  Error while processing rearrange-reduction pattern "b n (h d) -> b h n d".
 Input tensor shape: torch.Size([4, 4, 2305, 3]). Additional info: {'h': 1}.
 Expected 3 dimensions, got 4

In [None]:
class VisionTransformer(nn.Module):
    def __init__(
        self,
        image_size,
        patch_size,
        num_frames,
        n_layers,
        d_model,
        d_ff,
        n_heads,
        n_cls,
        dropout=0.1,
        drop_path_rate=0.0,
        distilled=False,
        channels=3
    ):

In [None]:
images=torch.rand(4,3,768,768)
model1=VisionTransformer(image_size=(768,768), patch_size=16,n_layers=12,d_model=192,n_cls=10,n_heads=3,d_ff=4*192)
logits1=model1(images)

In [None]:
data = [torch.rand(8,2048,192),torch.rand(8,2048,192)]
data1 = torch.cat(data,dim=0)
data1=data1.reshape(8,2,2048,192)
print(data1[2].shape)
#data1=rearrange(data1, 's b n w h -> b s n w h')
#data1=data1.permute(1,0,2,3)






### 

In [7]:
images =  torch.rand(1,4,3,768,768)
num_seqs = images.shape[1]
images=images.reshape(4,1,3,768,768)
print(images[3].shape)
vit_encoder =  VisionTransformer(image_size=(768,768), patch_size=16,n_layers=12,d_model=192,n_cls=10,n_heads=3,d_ff=4*192)
vit_encoderoutput=[]
for i in range(num_seqs):
    x=vit_encoder(images[i])
    vit_encoderoutput.append(x)
    
vit_embds = torch.stack(vit_encoderoutput)

vit_embds = vit_embds.reshape(1,4,2305, 192)
print("shape of vitembds",vit_embds[:, 0].shape)

for idx, emds in enumerate(vit_embds):
    print(emds.shape)
temporal_token = nn.Parameter(torch.randn(1,1,2305,192))

#vit_embds = vit_embds.reshape(4,2305,192)
#vit_embds = rearrange(vit_embds[:, 0], "(b t) ... -> b t ...", b=1)
print("after rearrange",vit_embds.shape)
print("temporal token",temporal_token.shape)

#cls_temporal_tokens = repeat(temporal_token, "b t n d -> b t n d", b=1,t=4)
#print("cls_temporal", cls_temporal_tokens.shape)
vit_embds = torch.cat((temporal_token, vit_embds), dim=1)
print("after cat operation temporal",vit_embds.shape)
vit_embds = vit_embds.reshape(1,5*2305,192)
print("newww", vit_embds.shape)

dpr = [x.item() for x in torch.linspace(0, 0.0, 12)]
blocks = nn.ModuleList([Block(192, 3, 4*192, 0.1, dpr[i]) for i in range(12)])
#x = self.temporal_transformer(x)
        
for blk in blocks:
    vit_embds = blk(vit_embds)
    
print(vit_embds.shape)


#vit_embds = temporal_transformer(vit_embds)
        

torch.Size([1, 3, 768, 768])
(768, 768)
Patch embedding embed dim 192
Vision transformer im shape torch.Size([1, 3, 768, 768])
Patch Embedding im shape torch.Size([1, 3, 768, 768])
Patch embedding After proj operation x shape torch.Size([1, 2304, 192])
ViT after patch embed operation torch.Size([1, 2304, 192])
ViT cls tokens torch.Size([1, 1, 192])
ViT after cat operatiom torch.Size([1, 2305, 192])
ViT pos_embed torch.Size([1, 2305, 192])
ViT after adding pos embedding torch.Size([1, 2305, 192])
torch.Size([1, 2305, 192])
torch.Size([1, 2305, 192])
torch.Size([1, 2305, 192])
torch.Size([1, 2305, 192])
torch.Size([1, 2305, 192])
torch.Size([1, 2305, 192])
torch.Size([1, 2305, 192])
torch.Size([1, 2305, 192])
torch.Size([1, 2305, 192])
torch.Size([1, 2305, 192])
torch.Size([1, 2305, 192])
torch.Size([1, 2305, 192])
ViT Last step of encoder torch.Size([1, 2305, 192])
Vision transformer im shape torch.Size([1, 3, 768, 768])
Patch Embedding im shape torch.Size([1, 3, 768, 768])
Patch embedd

# 