In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt
from collections import OrderedDict
from torchsummary import summary

class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and linearly embed each one.

    A Conv2d whose kernel size equals its stride (``patch_size``) applies the same
    linear projection to every patch; the resulting spatial grid is then flattened
    into a sequence of tokens.

    Args:
        img_size: Expected square input side length; only used to compute ``n_patches``.
        patch_size: Side length of each square patch (conv kernel size and stride).
        in_channels: Number of channels in the input image.
        embed_dim: Width of each output patch embedding.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Embed ``x`` of shape ``(B, in_channels, H, W)`` into ``(B, H' * W', embed_dim)`` tokens.

        ``H' * W'`` equals ``n_patches`` when ``H == W == img_size``.
        """
        # (B, embed_dim, H', W') -> (B, embed_dim, H'*W') -> (B, H'*W', embed_dim)
        return self.proj(x).flatten(start_dim=2).transpose(1, 2)

    def extra_repr(self):
        settings = (
            ("img_size", self.img_size),
            ("patch_size", self.patch_size),
            ("in_channels", self.proj.in_channels),
            ("embed_dim", self.proj.out_channels),
        )
        return ", ".join(f"{key}={value}" for key, value in settings)

class ReducedLinearAttention(nn.Module):
    """Multi-head self-attention computed in a channel-reduced space.

    Queries, keys and values are projected from ``dim`` down to
    ``reduced_dim = dim // reduction_ratio`` channels, split across ``num_heads``
    heads, attended with row-normalised exponential weights (i.e. softmax
    attention), and projected back to ``dim``.

    Note: despite the name, the full N x N attention matrix is materialised, so
    the cost is quadratic in the token count N; only the channel width is reduced.

    Args:
        dim: Input/output channel width.
        num_heads: Number of heads; must be positive and divide ``reduced_dim``.
        qkv_bias: Whether the q / kv projections use a bias.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
        reduction_ratio: Factor by which ``dim`` is shrunk for attention.

    Raises:
        ValueError: If ``reduced_dim`` is not positive, or ``num_heads`` is not a
            positive divisor of ``reduced_dim`` (the head reshape would fail).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., reduction_ratio=2):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        # Stored explicitly: dim // reduced_dim does not recover it when dim is
        # not a multiple of reduction_ratio.
        self.reduction_ratio = reduction_ratio
        self.reduced_dim = dim // reduction_ratio
        if num_heads <= 0 or self.reduced_dim <= 0 or self.reduced_dim % num_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a positive divisor of "
                f"reduced_dim = dim // reduction_ratio ({self.reduced_dim})"
            )
        self.head_dim = self.reduced_dim // num_heads
        self.scale = 1.0 / sqrt(self.head_dim)

        self.q = nn.Linear(dim, self.reduced_dim, bias=qkv_bias)
        # Keys and values share one projection; split into two halves in forward.
        self.kv = nn.Linear(dim, self.reduced_dim * 2, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(self.reduced_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over the tokens of ``x`` with shape ``(B, N, dim)``; returns ``(B, N, dim)``."""
        B, N, _ = x.shape

        # (B, N, reduced_dim) -> (B, heads, N, head_dim)
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # (B, N, 2*reduced_dim) -> (2, B, heads, N, head_dim)
        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        # softmax(z) == exp(z) / exp(z).sum(-1), but softmax subtracts the row max
        # first, so large logits cannot overflow to inf and produce NaN weights.
        k = k * self.scale
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = torch.matmul(attn, v)  # (B, heads, N, head_dim)
        x = x.transpose(1, 2).reshape(B, N, self.reduced_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self):
        return (f"dim={self.dim}, num_heads={self.num_heads}, reduced_dim={self.reduced_dim}, "
                f"head_dim={self.head_dim}, reduction_ratio={self.reduction_ratio}")

class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Input width.
        hidden_features: Hidden width; any falsy value (None, 0) falls back to
            ``in_features * 4``.
        out_features: Output width; any falsy value falls back to ``in_features``.
        drop: Dropout probability, shared by both dropout points.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        hidden = hidden_features if hidden_features else in_features * 4
        width_out = out_features if out_features else in_features

        self.fc1 = nn.Linear(in_features, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, width_out)
        # A single Dropout module is reused after both linear layers.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

    def extra_repr(self):
        parts = [
            f"in_features={self.fc1.in_features}",
            f"hidden_features={self.fc1.out_features}",
            f"out_features={self.fc2.out_features}",
        ]
        return ", ".join(parts)

class LinearTransformerBlock(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.

    Args:
        dim: Token width (input and output).
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim`` (truncated with ``int``).
        qkv_bias: Whether the attention q / kv projections use a bias.
        drop: Dropout after the attention projection and inside the MLP.
        attn_drop: Dropout on the attention weights.
        reduction_ratio: Channel reduction factor passed to the attention layer.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., reduction_ratio=2):
        super().__init__()
        # Keep the configured ratios for extra_repr: reconstructing them from layer
        # widths is lossy (int() truncates the MLP width, and dim // reduced_dim is
        # wrong whenever dim is not a multiple of reduction_ratio).
        self.mlp_ratio = mlp_ratio
        self.reduction_ratio = reduction_ratio

        self.norm1 = nn.LayerNorm(dim)
        self.attn = ReducedLinearAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            reduction_ratio=reduction_ratio
        )
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            drop=drop
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

    def extra_repr(self):
        return (f"dim={self.attn.dim}, num_heads={self.attn.num_heads}, "
                f"mlp_ratio={self.mlp_ratio:.1f}, "
                f"reduction_ratio={self.reduction_ratio}")

class LinearViT(nn.Module):
    """Vision Transformer classifier built from ``LinearTransformerBlock`` layers.

    Pipeline: patch embedding -> prepend a learned class token -> add learned
    positional embeddings -> dropout -> ``depth`` transformer blocks -> LayerNorm
    -> head applied to the class-token output.

    Args:
        img_size: Square input image side length.
        patch_size: Patch side length passed to ``PatchEmbedding``.
        in_channels: Number of input image channels.
        num_classes: Number of output classes. ``<= 0`` replaces the head with
            ``nn.Identity`` so ``forward`` returns the ``embed_dim``-wide features.
        embed_dim: Token embedding width.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        qkv_bias: Whether attention q / kv projections use a bias.
        drop_rate: Dropout after the positional embedding and inside blocks.
        attn_drop_rate: Dropout on attention weights.
        reduction_ratio: Channel reduction factor inside attention.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_rate=0.,
        attn_drop_rate=0.,
        reduction_ratio=2
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.reduction_ratio = reduction_ratio

        self.patch_embed = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim
        )

        # One learned class token, plus one positional embedding per patch and
        # one for the class token.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.n_patches + 1, embed_dim))

        self.pos_drop = nn.Dropout(p=drop_rate)

        self.blocks = nn.ModuleList([
            LinearTransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                reduction_ratio=reduction_ratio
            )
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        # With num_classes <= 0 the model acts as a feature extractor.
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        # Re-initialise every Linear / LayerNorm in the tree (Conv2d keeps its default init).
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Give Linear layers truncated-normal(std=0.02) weights and zero bias; reset LayerNorm to weight 1, bias 0."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) has weight/bias set to None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Run the classifier on an image batch.

        Args:
            x: Images ``(B, in_channels, H, W)``. H and W must yield exactly
                ``n_patches`` patches, or the positional-embedding add fails to broadcast.

        Returns:
            ``(B, num_classes)`` outputs, or ``(B, embed_dim)`` class-token features
            when the head is ``nn.Identity``.
        """
        B = x.shape[0]

        x = self.patch_embed(x)                       # (B, n_patches, embed_dim)

        cls_token = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_token, x), dim=1)          # (B, n_patches + 1, embed_dim)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        # Classify from the class-token position only.
        cls_token_final = x[:, 0]
        return self.head(cls_token_final)

    def get_layer_specs(self):
        """Return an ordered per-stage summary: output shape, parameter count and details."""
        specs = OrderedDict()
        n_tokens = self.patch_embed.n_patches + 1  # patches + class token
        token_shape = f"(batch_size, {n_tokens}, {self.embed_dim})"

        specs["Input"] = {
            "shape": f"(batch_size, {self.in_channels}, {self.img_size}, {self.img_size})",
            "params": 0,
            "details": "Input image"
        }

        specs["PatchEmbedding"] = {
            "shape": f"(batch_size, {self.patch_embed.n_patches}, {self.embed_dim})",
            "params": sum(p.numel() for p in self.patch_embed.parameters()),
            "details": self.patch_embed.extra_repr()
        }

        specs["PositionEmbedding"] = {
            "shape": f"(1, {n_tokens}, {self.embed_dim})",
            "params": self.pos_embed.numel() + self.cls_token.numel(),
            "details": "Learned positional embeddings + class token"
        }

        specs["PosDrop"] = {
            "shape": token_shape,
            "params": 0,
            "details": f"Dropout(p={self.pos_drop.p})"
        }

        for i, block in enumerate(self.blocks):
            specs[f"LinearTransformerBlock_{i+1}"] = {
                "shape": token_shape,
                "params": sum(p.numel() for p in block.parameters()),
                "details": block.extra_repr()
            }

        specs["LayerNorm"] = {
            "shape": token_shape,
            "params": sum(p.numel() for p in self.norm.parameters()),
            "details": f"LayerNorm({self.embed_dim})"
        }

        # Classification head. nn.Identity has no parameters and passes the
        # embed_dim-wide class-token features through unchanged.
        if self.num_classes > 0:
            head_shape = f"(batch_size, {self.num_classes})"
            head_details = f"Linear({self.embed_dim}->{self.num_classes})"
        else:
            head_shape = f"(batch_size, {self.embed_dim})"
            head_details = "Identity"
        specs["Head"] = {
            "shape": head_shape,
            "params": sum(p.numel() for p in self.head.parameters()),
            "details": head_details
        }

        return specs

def print_model_summary(model):
    """Print a human-readable architecture header followed by a per-layer table.

    Hyper-parameters are read from attributes of ``model``. Rows come from
    ``model.get_layer_specs()``, and the reported total is the sum of the integer
    ``params`` entries in those rows.

    Args:
        model: A ``LinearViT`` or any object exposing ``img_size``, ``patch_size``,
            ``in_channels``, ``embed_dim``, ``num_heads``, ``blocks``,
            ``reduction_ratio``, ``num_classes`` and ``get_layer_specs()``.
    """
    print("=" * 80)
    print(f"{'Linear Vision Transformer (ViT) Summary':^80}")
    print("=" * 80)
    print(f"Image size: {model.img_size}x{model.img_size}")
    print(f"Patch size: {model.patch_size}x{model.patch_size}")
    print(f"Input channels: {model.in_channels}")
    print(f"Embedding dimension: {model.embed_dim}")
    print(f"Number of heads: {model.num_heads}")
    print(f"Number of layers: {len(model.blocks)}")
    # A depth-0 model has no block to read the MLP hidden width from.
    if len(model.blocks) > 0:
        mlp_ratio = f"{model.blocks[0].mlp.fc1.out_features/model.embed_dim:.1f}"
    else:
        mlp_ratio = "n/a"
    print(f"MLP ratio: {mlp_ratio}")
    print(f"Attention reduction ratio: {model.reduction_ratio}")
    print(f"Number of classes: {model.num_classes}")
    print("=" * 80)
    print("\nLayer Details:\n")

    specs = model.get_layer_specs()

    def _params_str(params):
        # Integer counts get thousands separators; anything else is shown verbatim.
        return f"{params:,}" if isinstance(params, int) else params

    rows = [(name, spec, _params_str(spec["params"])) for name, spec in specs.items()]

    # default=0 keeps an empty spec table from crashing max().
    max_layer_len = max((len(name) for name, _, _ in rows), default=0)
    max_shape_len = max((len(spec["shape"]) for _, spec, _ in rows), default=0)
    max_params_len = max((len(params_str) for _, _, params_str in rows), default=0)

    total_params = sum(spec["params"] for _, spec, _ in rows if isinstance(spec["params"], int))

    for name, spec, params_str in rows:
        print(f"{name:<{max_layer_len}} | Shape: {spec['shape']:<{max_shape_len}} | "
              f"Params: {params_str:<{max_params_len}} | Details: {spec['details']}")

    print("\n" + "=" * 80)
    print(f"{'Total Parameters':<{max_layer_len}} : {total_params:,}")
    print("=" * 80)

if __name__ == "__main__":
    # Fix the RNG so the randomly initialised weights are reproducible across re-runs.
    torch.manual_seed(0)

    config = {
        "img_size": 224,
        "patch_size": 16,
        "in_channels": 3,
        "num_classes": 2,
        "embed_dim": 768,
        "depth": 8,
        "num_heads": 8,
        "mlp_ratio": 4.0,
        "qkv_bias": True,
        "drop_rate": 0.1,
        "attn_drop_rate": 0.1,
        "reduction_ratio": 4
    }

    model = LinearViT(**config)

    print_model_summary(model)

    # Derive the dummy input shape from the config so the torchsummary pass stays
    # valid if img_size / in_channels are changed above.
    input_shape = (config["in_channels"], config["img_size"], config["img_size"])

    print("\nStandard Torch Summary:")
    summary(model, input_shape, device="cpu")

                    Linear Vision Transformer (ViT) Summary                     
Image size: 224x224
Patch size: 16x16
Input channels: 3
Embedding dimension: 768
Number of heads: 8
Number of layers: 8
MLP ratio: 4.0
Attention reduction ratio: 4
Number of classes: 2

Layer Details:

Input                    | Shape: (batch_size, 3, 224, 224) | Params: 0         | Details: Input image
PatchEmbedding           | Shape: (batch_size, 196, 768)    | Params: 590,592   | Details: img_size=224, patch_size=16, in_channels=3, embed_dim=768
PositionEmbedding        | Shape: (1, 197, 768)             | Params: 152,064   | Details: Learned positional embeddings + class token
PosDrop                  | Shape: (batch_size, 197, 768)    | Params: 0         | Details: Dropout(p=0.1)
LinearTransformerBlock_1 | Shape: (batch_size, 197, 768)    | Params: 5,316,672 | Details: dim=768, num_heads=8, mlp_ratio=4.0, reduction_ratio=4
LinearTransformerBlock_2 | Shape: (batch_size, 197, 768)    | Params: 5,316,67

In [3]:
from ptflops import get_model_complexity_info
# NOTE(review): this cell reuses `model` from the model-definition cell (executed
# as In [4] after this one as In [3]); run that cell first on a fresh kernel.
# ptflops reports multiply-accumulate operations (MACs), not FLOPs (roughly 2 FLOPs per MAC).
# NOTE(review): the per-module counts appear to omit the functional q @ k^T and
# attn @ v matmuls inside ReducedLinearAttention; verify before quoting totals.
input_res = (model.in_channels, model.img_size, model.img_size)
macs, params = get_model_complexity_info(model, input_res, as_strings=True)
print(f"MACs: {macs}, Parameters: {params}")

LinearViT(
  43.13 M, 99.649% Params, 8.5 GMac, 99.247% MACs, 
  (patch_embed): PatchEmbedding(
    590.59 k, 1.365% Params, 115.76 MMac, 1.352% MACs, img_size=224, patch_size=16, in_channels=3, embed_dim=768
    (proj): Conv2d(590.59 k, 1.365% Params, 115.76 MMac, 1.352% MACs, 3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(0, 0.000% Params, 0.0 Mac, 0.000% MACs, p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-7): 8 x LinearTransformerBlock(
      5.32 M, 12.285% Params, 1.05 GMac, 12.237% MACs, dim=768, num_heads=8, mlp_ratio=4.0, reduction_ratio=4
      (norm1): LayerNorm(1.54 k, 0.004% Params, 151.3 KMac, 0.002% MACs, (768,), eps=1e-05, elementwise_affine=True)
      (attn): ReducedLinearAttention(
        591.17 k, 1.366% Params, 116.46 MMac, 1.360% MACs, dim=768, num_heads=8, reduced_dim=192, head_dim=24, reduction_ratio=4
        (q): Linear(147.65 k, 0.341% Params, 29.09 MMac, 0.340% MACs, in_features=768, out_features=192, bias=True)
        (kv):