In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

# All classes are refactored to match timm's naming conventions.

class PatchEmbeddings(nn.Module):
    """
    Splits image into patches and embeds them.
    This class now matches the 'patch_embed' module in timm's ViT.
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embedding_dim=192):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # Changed 'patching' to 'proj' to match timm
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embedding_dim,
            kernel_size=patch_size, # kernel_size should be patch_size
            stride=patch_size,
            padding=0
        )
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, x):
        x = self.proj(x)  # (B, C, H, W) -> (B, embedding_dim, H', W')
        x = self.flatten(x) # (B, embedding_dim, H', W') -> (B, embedding_dim, num_patches)
        x = x.permute(0, 2, 1) # (B, embedding_dim, num_patches) -> (B, num_patches, embedding_dim)
        return x

class MultiheadSelfAttention(nn.Module):
    """
    Multi-Head Self-Attention module -- MODIFIED to match timm's architecture.
    This version uses a single Linear layer for Q, K, V for efficiency and
    to allow direct weight loading from timm.
    """
    def __init__(self, embedding_dim=192, num_heads=3, dropout=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads
        self.scale = self.head_dim ** -0.5

        # MODIFICATION: A single Linear layer for Q, K, and V
        # The output dimension is 3 * embedding_dim because it holds Q, K, and V concatenated.
        self.qkv = nn.Linear(embedding_dim, embedding_dim * 3, bias=True)

        # The projection layer remains the same
        self.proj = nn.Linear(embedding_dim, embedding_dim)
        
        self.attn_drop = nn.Dropout(dropout)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        batch_size, n_patches, dim = x.shape

        # MODIFICATION: Project x once to get q, k, and v together
        # (B, N, D) -> (B, N, 3*D)
        qkv = self.qkv(x)

        # Reshape and split qkv into q, k, and v for multi-head attention
        # (B, N, 3*D) -> (B, N, 3, num_heads, head_dim)
        qkv = qkv.reshape(batch_size, n_patches, 3, self.num_heads, self.head_dim)
        
        # Permute to (3, B, num_heads, N, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        
        # Split into q, k, v. Each will have shape (B, num_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled Dot-Product Attention
        # (B, num_heads, N, head_dim) @ (B, num_heads, head_dim, N) -> (B, num_heads, N, N)
        attn_score = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn_score.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, num_heads, N, N) @ (B, num_heads, N, head_dim) -> (B, num_heads, N, head_dim)
        x = (attn @ v).transpose(1, 2).reshape(batch_size, n_patches, dim)
        
        # Final projection
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class MLP(nn.Module):
    """
    MLP block. Refactored to use named 'fc1' and 'fc2' like timm.
    """
    def __init__(self, embedding_dim=192, mlp_ratio=4, dropout=0.0):
        super().__init__()
        mlp_hidden_dim = int(embedding_dim * mlp_ratio)
        self.fc1 = nn.Linear(embedding_dim, mlp_hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(mlp_hidden_dim, embedding_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class TransformerEncoderBlock(nn.Module):
    """
    A single Transformer Encoder block.
    This structure now directly mirrors timm's 'blocks.i' module.
    It contains norm1, attn, norm2, and mlp.
    """
    def __init__(self, embedding_dim=192, num_heads=3, mlp_ratio=4., dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.attn = MultiheadSelfAttention(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            dropout=dropout
        )
        self.norm2 = nn.LayerNorm(embedding_dim)
        self.mlp = MLP(
            embedding_dim=embedding_dim,
            mlp_ratio=mlp_ratio,
            dropout=dropout
        )

    def forward(self, x):
        # Pre-Norm architecture
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class VitTiny(nn.Module):
    """
    The main Vision Transformer class.
    All module names are now aligned with the timm library.
    """
    def __init__(self, 
                 img_size=224,
                 in_channels=3,
                 patch_size=16,
                 num_transformer_layers=12,
                 embedding_dim=192,
                 mlp_ratio=4,
                 num_heads=3,
                 num_classes=1000, # Default for ImageNet-1k
                 dropout=0.0):
        super().__init__()

        # Renamed 'patch_embed_layer' to 'patch_embed'
        self.patch_embed = PatchEmbeddings(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embedding_dim=embedding_dim
        )
        
        # Renamed 'cls' to 'cls_token'
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
        
        # Renamed 'pos_emb' to 'pos_embed'
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embedding_dim))
        self.pos_drop = nn.Dropout(p=dropout)

        # Renamed 'enc_blocks' to 'blocks'
        self.blocks = nn.Sequential(*[
            TransformerEncoderBlock(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout
            ) for _ in range(num_transformer_layers)])
        
        # Final LayerNorm and classifier head, matching timm's naming
        self.norm = nn.LayerNorm(embedding_dim)
        self.head = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.patch_embed(x)
        
        # Prepend class token
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        
        # Add positional embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)
        
        x = self.blocks(x)
        
        x = self.norm(x)
        
        # Classifier head only uses the class token
        cls_token_final = x[:, 0]
        x = self.head(cls_token_final)
        
        return x

In [1]:
from model import ViT

In [2]:
import json

def read_config(config_file_path: str):
    try:
        with open(config_file_path, 'r') as json_file:
            data = json.load(json_file)
        return data
    except:
        return FileNotFoundError("The config file is corrupted/absent")
    
configs = read_config(r"C:\Users\e87299\Desktop\Training\Week2-Pytorch\Framework-pytorch\configs\configs_classification.json")

In [3]:
import timm
import torch

my_vit_perfect_replica = ViT(config=configs["model"])

# my_vit_perfect_replica = VitTiny(num_classes=1000)


# 2. Create the timm model to get its state dictionary
print("Loading timm model state_dict...")
timm_model = timm.create_model(
    'vit_tiny_patch16_224.augreg_in21k_ft_in1k',
    pretrained=True,
    num_classes=2
)
timm_state_dict = timm_model.state_dict()

# 3. Load the weights directly!
# This will now work because every key and tensor shape matches perfectly.
print("Loading state_dict into custom model...")
my_vit_perfect_replica.load_state_dict(timm_state_dict, strict=True)

print("\nWeights loaded successfully using direct method!")

# Verify the model works
my_vit_perfect_replica.eval()
dummy_input = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = my_vit_perfect_replica(dummy_input)

print(f"Test forward pass successful. Output shape: {output.shape}")

  from .autonotebook import tqdm as notebook_tqdm


Loading timm model state_dict...
Loading state_dict into custom model...

Weights loaded successfully using direct method!
Test forward pass successful. Output shape: torch.Size([1, 2])
