In [None]:
!pip install timm

In [30]:
import numpy as np
import timm
import torch
from custom import VisionTransformer

def get_n_params(module):
    """Count the learnable parameters of a module.

    Parameters
    ----------
    module
        Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns
    -------
    int
        Sum of ``numel()`` over all parameters with ``requires_grad=True``.
    """
    total = 0
    for param in module.parameters():
        # Frozen parameters are not learnable and are therefore skipped
        if param.requires_grad:
            total += param.numel()
    return total

def assert_tensors_equal(t1, t2, rtol=1e-05, atol=3e-06):
    """Assert that two tensors are element-wise equal within a tolerance.

    Parameters
    ----------
    t1, t2 : torch.Tensor
        Tensors to compare. They may require grad and may live on any device.
    rtol : float
        Relative tolerance forwarded to ``np.testing.assert_allclose``.
    atol : float
        Absolute tolerance forwarded to ``np.testing.assert_allclose``.

    Raises
    ------
    AssertionError
        If the tensors differ by more than the given tolerances.
    """
    # detach() drops the autograd graph; cpu() is required because numpy()
    # only works on host memory (it is a no-op for tensors already on the CPU)
    a1, a2 = t1.detach().cpu().numpy(), t2.detach().cpu().numpy()
    np.testing.assert_allclose(a1, a2, rtol=rtol, atol=atol)

In [31]:
# Identifier of the reference architecture, as understood by timm.create_model
model_name = "vit_base_patch16_384"

# Build the reference model with pretrained weights and switch it to evaluation
# mode right away (eval() returns the module itself, so the call can be chained)
model_official = timm.create_model(model_name, pretrained=True).eval()

# Report the learnable-parameter count and the concrete class of the reference model
print("Total number of parameters in official model:", get_n_params(model_official))
print("Type of model_official:", type(model_official))

Total number of parameters in official model: 86859496
Type of model_official: <class 'timm.models.vision_transformer.VisionTransformer'>


In [32]:
# Keyword arguments used to construct the custom VisionTransformer
custom_config = dict(
    img_size=384,
    in_chans=3,
    patch_size=16,
    embed_dim=768,
    depth=12,
    n_heads=12,
    qkv_bias=True,
    mlp_ratio=4,
)

# Instantiate the custom implementation from the configuration above
model_custom = VisionTransformer(**custom_config)

# Report its learnable-parameter count
print("Total number of parameters in custom model:", get_n_params(model_custom))

# Switch to evaluation mode; kept as the last expression so the cell displays the module structure
model_custom.eval()

Total number of parameters in custom model: 86859496


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0, inplace=False)
        (proj_drop): Dropout(p=0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (head): Linear(in_features=768, out_features=1000, bias=True)

In [33]:
# Materialize both parameter lists so their lengths can be compared up front
params_official = list(model_official.named_parameters())
params_custom = list(model_custom.named_parameters())

# zip() stops at the shorter iterable, so without this check a structural
# mismatch would silently leave some parameters uncopied
assert len(params_official) == len(params_custom), (
    f"Parameter count mismatch: official has {len(params_official)} tensors, "
    f"custom has {len(params_custom)}"
)

# Walk both models in registration order and copy the weights pairwise
for (n_o, p_o), (n_c, p_c) in zip(params_official, params_custom):
    # Paired parameters must hold the same number of elements
    assert p_o.numel() == p_c.numel(), (
        f"Size mismatch for {n_o} | {n_c}: {p_o.numel()} != {p_c.numel()}"
    )

    # Print the names of the corresponding parameters
    print(f"{n_o} | {n_c}")

    # Copy the values from the official parameter into the custom one, in place
    p_c.data[:] = p_o.data

    # Verify that the copy produced matching values
    assert_tensors_equal(p_c.data, p_o.data)


cls_token | cls_token
pos_embed | pos_embed
patch_embed.proj.weight | patch_embed.proj.weight
patch_embed.proj.bias | patch_embed.proj.bias
blocks.0.norm1.weight | blocks.0.norm1.weight
blocks.0.norm1.bias | blocks.0.norm1.bias
blocks.0.attn.qkv.weight | blocks.0.attn.qkv.weight
blocks.0.attn.qkv.bias | blocks.0.attn.qkv.bias
blocks.0.attn.proj.weight | blocks.0.attn.proj.weight
blocks.0.attn.proj.bias | blocks.0.attn.proj.bias
blocks.0.norm2.weight | blocks.0.norm2.weight
blocks.0.norm2.bias | blocks.0.norm2.bias
blocks.0.mlp.fc1.weight | blocks.0.mlp.fc1.weight
blocks.0.mlp.fc1.bias | blocks.0.mlp.fc1.bias
blocks.0.mlp.fc2.weight | blocks.0.mlp.fc2.weight
blocks.0.mlp.fc2.bias | blocks.0.mlp.fc2.bias
blocks.1.norm1.weight | blocks.1.norm1.weight
blocks.1.norm1.bias | blocks.1.norm1.bias
blocks.1.attn.qkv.weight | blocks.1.attn.qkv.weight
blocks.1.attn.qkv.bias | blocks.1.attn.qkv.bias
blocks.1.attn.proj.weight | blocks.1.attn.proj.weight
blocks.1.attn.proj.bias | blocks.1.attn.proj.b

In [34]:
# Random input batch: one 3-channel 384x384 image
inp = torch.rand(1, 3, 384, 384)

# Only the forward outputs are compared, so skip building the autograd graph
with torch.no_grad():
    res_c = model_custom(inp)
    res_o = model_official(inp)

# Both models must have the same number of learnable parameters and produce
# the same output (within tolerance) for the same input
assert get_n_params(model_custom) == get_n_params(model_official)
assert_tensors_equal(res_c, res_o)

# Save custom model.
# NOTE: this pickles the whole module object (not just a state_dict), so the
# class definition must be importable wherever "model.pth" is loaded.
torch.save(model_custom, "model.pth")

In [35]:
import numpy as np
from PIL import Image
import torch

# Number of top predictions to display
k = 10

# Map class index -> label; the context manager makes sure the file is closed
with open("classes.txt") as labels_file:
    imagenet_labels = {ix: line.strip() for ix, line in enumerate(labels_file)}

# "model.pth" holds a fully pickled module, so weights_only must be False
# (newer PyTorch releases default to weights_only=True and would refuse it).
# NOTE(security): unpickling can execute arbitrary code - only load files you created or trust.
model = torch.load("model.pth", weights_only=False)
model.eval()

# Force 3 RGB channels (PNG files may carry an alpha channel or a palette) and
# close the image file once the pixel data has been copied into the array
with Image.open("cat.png") as pil_img:
    img = (np.array(pil_img.convert("RGB")) / 128) - 1  # uint8 [0, 255] -> roughly [-1, 1)

# HWC -> CHW, add a batch dimension, and cast to float32
inp = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(torch.float32)

# Inference only: no autograd graph is needed
with torch.no_grad():
    logits = model(inp)
probs = torch.nn.functional.softmax(logits, dim=-1)

top_probs, top_ixs = probs[0].topk(k)

# Print the k most probable classes with their probabilities
for i, (ix_, prob_) in enumerate(zip(top_ixs, top_probs)):
    ix = ix_.item()
    prob = prob_.item()
    cls = imagenet_labels[ix]
    print(f"{i}: {cls:<45} --- {prob:.4f}")

0: tabby, tabby_cat                              --- 0.8001
1: tiger_cat                                     --- 0.1752
2: Egyptian_cat                                  --- 0.0172
3: lynx, catamount                               --- 0.0018
4: Persian_cat                                   --- 0.0011
5: Siamese_cat, Siamese                          --- 0.0002
6: bow_tie, bow-tie, bowtie                      --- 0.0002
7: weasel                                        --- 0.0001
8: lens_cap, lens_cover                          --- 0.0001
9: remote_control, remote                        --- 0.0001
