In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt
from collections import OrderedDict
from torchsummary import summary

class PatchEmbedding(nn.Module):
    """
    Turn an image into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects every non-overlapping patch to an ``embed_dim`` vector.

    Args:
        img_size: Side length used to compute ``n_patches``.
        patch_size: Side length of one square patch.
        in_channels: Number of channels of the input image.
        embed_dim: Size of the embedding produced for each patch.
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size

        # Number of whole patches that fit along one side, squared.
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """
        Input shape: (batch_size, channels, height, width)
        Output shape: (batch_size, n_patches, embed_dim)
        """
        feature_map = self.proj(x)                 # (batch_size, embed_dim, n_patches_h, n_patches_w)
        tokens = feature_map.flatten(start_dim=2)  # (batch_size, embed_dim, n_patches)
        return tokens.permute(0, 2, 1)             # (batch_size, n_patches, embed_dim)

    def extra_repr(self):
        fields = (
            f"img_size={self.img_size}",
            f"patch_size={self.patch_size}",
            f"in_channels={self.proj.in_channels}",
            f"embed_dim={self.proj.out_channels}",
        )
        return ", ".join(fields)

class Attention(nn.Module):
    """
    Multi-head self-attention built on scaled dot-product attention.

    Queries, keys and values are produced by one fused linear layer, split
    into ``num_heads`` heads of size ``dim // num_heads``, attended per head,
    merged back together and passed through an output projection.

    Args:
        dim: Size of the last dimension of the input and output.
        num_heads: Number of attention heads; must divide ``dim`` evenly.
        qkv_bias: Whether the fused q/k/v projection has a bias term.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        # Fail at construction time: otherwise the reshape in forward() breaks
        # with an opaque shape error on every call.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # 1/sqrt(head_dim) scaling of the attention logits.
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        Input shape: (B, N, C) with C == dim
        Output shape: (B, N, C)
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each of shape (B, num_heads, N, head_dim)

        # Scaled similarity of every query with every key, normalised over keys.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum of values, then merge heads back into the channel dim.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self):
        return f"dim={self.dim}, num_heads={self.num_heads}, head_dim={self.head_dim}"

class MLP(nn.Module):
    """
    Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Size of the input's last dimension.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features * 4`` when omitted (or otherwise falsy).
        out_features: Size of the output's last dimension; falls back to
            ``in_features`` when omitted (or otherwise falsy).
        drop: Dropout probability, used after the activation and after
            the second linear layer.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features * 4

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # A single dropout module is reused at both positions in forward().
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

    def extra_repr(self):
        fields = (
            f"in_features={self.fc1.in_features}",
            f"hidden_features={self.fc1.out_features}",
            f"out_features={self.fc2.out_features}",
        )
        return ", ".join(fields)

class TransformerBlock(nn.Module):
    """
    Pre-norm Transformer encoder block.

    Each of the two sub-layers (self-attention, then MLP) is applied to a
    LayerNorm-ed copy of the input and added back through a residual
    connection.

    Args:
        dim: Size of the token embedding.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the attention q/k/v projection has a bias.
        drop: Dropout probability for the attention output projection and the MLP.
        attn_drop: Dropout probability for the attention weights.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
        super().__init__()
        mlp_hidden = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden, drop=drop)

    def forward(self, x):
        # Attention sub-layer with residual connection.
        attended = self.attn(self.norm1(x))
        x = x + attended
        # Feed-forward sub-layer with residual connection.
        transformed = self.mlp(self.norm2(x))
        return x + transformed

    def extra_repr(self):
        # The ratio is recovered from the actual layer widths.
        ratio = self.mlp.fc1.out_features / self.attn.dim
        return f"dim={self.attn.dim}, num_heads={self.attn.num_heads}, mlp_ratio={ratio:.1f}"

class VisionTransformer(nn.Module):
    """
    Standard Vision Transformer (ViT) implementation.

    The image is split into patches and embedded, a learnable class token is
    prepended, learned position embeddings are added, the sequence runs
    through ``depth`` Transformer blocks, and the final-normalised class
    token is fed to the classification head.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of one (square) patch.
        in_channels: Number of input image channels.
        num_classes: Output size of the head; ``<= 0`` replaces the head
            with ``nn.Identity`` so the class-token embedding is returned.
        embed_dim: Token embedding size.
        depth: Number of Transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        qkv_bias: Whether attention q/k/v projections have a bias.
        drop_rate: Dropout probability after the position embedding and
            inside each block.
        attn_drop_rate: Dropout probability on attention weights.
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_rate=0.,
        attn_drop_rate=0.,
    ):
        super().__init__()
        # Kept as plain attributes so summaries can report the configuration.
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.depth = depth

        # Patch embedding
        self.patch_embed = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim
        )

        # Class token and position embedding (one position per patch + class token)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.n_patches + 1, embed_dim))

        # Dropout after positional embedding
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
            )
            for _ in range(depth)
        ])

        # Head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        # Initialize weights
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser passed to ``self.apply``.

        Linear layers get truncated-normal weights (std 0.02) and zero bias;
        LayerNorm gets unit weight and zero bias. Other modules are left
        with their default initialisation.
        """
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """
        Input shape: (B, in_channels, img_size, img_size)
        Output shape: (B, num_classes), or (B, embed_dim) when the head is Identity
        """
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # (B, n_patches, embed_dim)

        # Add class token
        cls_token = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_token, x), dim=1)  # (B, n_patches+1, embed_dim)

        # Add position embedding (broadcast over the batch dimension)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        # Classification head
        x = self.norm(x)
        cls_token_final = x[:, 0]  # Take only the class token
        x = self.head(cls_token_final)

        return x

    def get_layer_specs(self):
        """Returns a detailed summary of each layer's specifications.

        Returns:
            OrderedDict mapping a layer label to a dict with the keys
            ``"shape"`` (output shape as a string), ``"params"`` (parameter
            count as an int) and ``"details"`` (short description), in
            forward-pass order.
        """
        specs = OrderedDict()
        seq_shape = f"(batch_size, {self.patch_embed.n_patches + 1}, {self.embed_dim})"

        # Input specs
        specs["Input"] = {
            "shape": f"(batch_size, {self.in_channels}, {self.img_size}, {self.img_size})",
            "params": 0,
            "details": "Input image"
        }

        # Patch embedding
        patch_params = sum(p.numel() for p in self.patch_embed.parameters())
        specs["PatchEmbedding"] = {
            "shape": f"(batch_size, {self.patch_embed.n_patches}, {self.embed_dim})",
            "params": patch_params,
            "details": self.patch_embed.extra_repr()
        }

        # Class token and position embedding
        pos_params = self.pos_embed.numel() + self.cls_token.numel()
        specs["PositionEmbedding"] = {
            "shape": f"(1, {self.patch_embed.n_patches + 1}, {self.embed_dim})",
            "params": pos_params,
            "details": "Learned positional embeddings + class token"
        }

        # Positional dropout
        specs["PosDrop"] = {
            "shape": seq_shape,
            "params": 0,
            "details": f"Dropout(p={self.pos_drop.p})"
        }

        # Transformer blocks
        for i, block in enumerate(self.blocks):
            block_params = sum(p.numel() for p in block.parameters())
            specs[f"TransformerBlock_{i+1}"] = {
                "shape": seq_shape,
                "params": block_params,
                "details": block.extra_repr()
            }

        # Final normalization
        norm_params = sum(p.numel() for p in self.norm.parameters())
        specs["LayerNorm"] = {
            "shape": seq_shape,
            "params": norm_params,
            "details": f"LayerNorm({self.embed_dim})"
        }

        # Classification head. With num_classes <= 0 the head is nn.Identity,
        # so the output is the class-token embedding of size embed_dim
        # (not a zero-width tensor) and it has no parameters.
        has_classifier = self.num_classes > 0
        head_width = self.num_classes if has_classifier else self.embed_dim
        head_params = sum(p.numel() for p in self.head.parameters())
        specs["Head"] = {
            "shape": f"(batch_size, {head_width})",
            "params": head_params,
            "details": f"Linear({self.embed_dim}->{self.num_classes})" if has_classifier else "Identity"
        }

        return specs

def print_model_summary(model):
    """Prints a detailed summary of the model architecture.

    Writes a configuration header, one aligned row per entry of
    ``model.get_layer_specs()`` and the total parameter count to stdout.

    Args:
        model: Object exposing ``img_size``, ``patch_size``, ``in_channels``,
            ``embed_dim``, ``num_heads``, ``depth``, ``num_classes``,
            ``blocks`` and ``get_layer_specs()``. Spec entries whose
            ``"params"`` value is not an int are printed verbatim and
            excluded from the total.
    """
    rule = "=" * 80

    # A model with no blocks has no MLP to derive the ratio from.
    if len(model.blocks) > 0:
        mlp_ratio = f"{model.blocks[0].mlp.fc1.out_features/model.embed_dim:.1f}"
    else:
        mlp_ratio = "N/A"

    print(rule)
    print(f"{'Vision Transformer (ViT) Summary':^80}")
    print(rule)
    print(f"Image size: {model.img_size}x{model.img_size}")
    print(f"Patch size: {model.patch_size}x{model.patch_size}")
    print(f"Input channels: {model.in_channels}")
    print(f"Embedding dimension: {model.embed_dim}")
    print(f"Number of heads: {model.num_heads}")
    print(f"Number of layers: {model.depth}")
    print(f"MLP ratio: {mlp_ratio}")
    print(f"Number of classes: {model.num_classes}")
    print(rule)
    print("\nLayer Details:\n")

    # Format every row once; the same strings drive both width and output.
    rows = []
    total_params = 0
    for name, spec in model.get_layer_specs().items():
        params = spec["params"]
        if isinstance(params, int):
            total_params += params
            params_str = f"{params:,}"
        else:
            params_str = params
        rows.append((name, spec["shape"], params_str, spec["details"]))

    # default=0 keeps the function usable when there are no spec rows.
    max_layer_len = max((len(row[0]) for row in rows), default=0)
    max_shape_len = max((len(row[1]) for row in rows), default=0)
    max_params_len = max((len(row[2]) for row in rows), default=0)

    for name, shape, params_str, details in rows:
        print(f"{name:<{max_layer_len}} | Shape: {shape:<{max_shape_len}} | "
              f"Params: {params_str:<{max_params_len}} | Details: {details}")

    print("\n" + rule)
    print(f"{'Total Parameters':<{max_layer_len}} : {total_params:,}")
    print(rule)

# Create and test the model
if __name__ == "__main__":
    # Model hyperparameters; every key is a VisionTransformer constructor argument.
    config = {
        "img_size": 224,
        "patch_size": 16,
        "in_channels": 3,
        "num_classes": 2,
        "embed_dim": 768,
        "depth": 8,
        "num_heads": 8,
        "mlp_ratio": 4.0,
        "qkv_bias": True,
        "drop_rate": 0.1,
        "attn_drop_rate": 0.1
    }

    # Create model
    model = VisionTransformer(**config)

    # Print detailed summary
    print_model_summary(model)

    # Print standard torchsummary. The input size is derived from the config
    # so the two stay in sync when the image size or channel count changes.
    print("\nStandard Torch Summary:")
    input_size = (config["in_channels"], config["img_size"], config["img_size"])
    summary(model, input_size, device="cpu")

                        Vision Transformer (ViT) Summary                        
Image size: 224x224
Patch size: 16x16
Input channels: 3
Embedding dimension: 768
Number of heads: 8
Number of layers: 8
MLP ratio: 4.0
Number of classes: 2

Layer Details:

Input              | Shape: (batch_size, 3, 224, 224) | Params: 0         | Details: Input image
PatchEmbedding     | Shape: (batch_size, 196, 768)    | Params: 590,592   | Details: img_size=224, patch_size=16, in_channels=3, embed_dim=768
PositionEmbedding  | Shape: (1, 197, 768)             | Params: 152,064   | Details: Learned positional embeddings + class token
PosDrop            | Shape: (batch_size, 197, 768)    | Params: 0         | Details: Dropout(p=0.1)
TransformerBlock_1 | Shape: (batch_size, 197, 768)    | Params: 7,087,872 | Details: dim=768, num_heads=8, mlp_ratio=4.0
TransformerBlock_2 | Shape: (batch_size, 197, 768)    | Params: 7,087,872 | Details: dim=768, num_heads=8, mlp_ratio=4.0
TransformerBlock_3 | Shape: (batch_