# Vision Transformer
## Imports and helpers

In [2]:
import numpy as np
import timm
import torch
from vit.model.vit import VisionTransformer


def parameter_count(module: torch.nn.Module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def assert_close(t1: torch.Tensor, t2: torch.Tensor):
    a1, a2 = t1.detach().numpy(), t2.detach().numpy()
    np.testing.assert_allclose(a1, a2)

## Official Model

In [3]:
timm.list_models(pretrained=True)

['adv_inception_v3',
 'cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 'cait_xxs24_384',
 'cait_xxs36_224',
 'cait_xxs36_384',
 'coat_lite_mini',
 'coat_lite_small',
 'coat_lite_tiny',
 'coat_mini',
 'coat_tiny',
 'convit_base',
 'convit_small',
 'convit_tiny',
 'cspdarknet53',
 'cspresnet50',
 'cspresnext50',
 'deit_base_distilled_patch16_224',
 'deit_base_distilled_patch16_384',
 'deit_base_patch16_224',
 'deit_base_patch16_384',
 'deit_small_distilled_patch16_224',
 'deit_small_patch16_224',
 'deit_tiny_distilled_patch16_224',
 'deit_tiny_patch16_224',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenetblur121d',
 'dla34',
 'dla46_c',
 'dla46x_c',
 'dla60',
 'dla60_res2net',
 'dla60_res2next',
 'dla60x',
 'dla60x_c',
 'dla102',
 'dla102x',
 'dla102x2',
 'dla169',
 'dm_nfnet_f0',
 'dm_nfnet_f1',
 'dm_nfnet_f2',
 'dm_nfnet_f3',
 'dm_nfnet_f4',
 'dm_nfnet_f5',
 'dm_nfnet_f6',
 'dpn68',
 'dpn

In [4]:
model_name = 'vit_base_patch16_384'
model_official: torch.nn.Module = timm.create_model(model_name, pretrained=True)
model_official.eval()
print(type(model_official))

URLError: <urlopen error [WinError 10060] A connection attempt failed because the connected party did not properly respond after a period of time, or established connection failed because connected host has failed to respond>

## Custom Model

In [None]:
model_custom: torch.nn.Module = VisionTransformer(image_size=384,
                                                  in_channels=3,
                                                  patch_size=16,
                                                  embed_dim=768,
                                                  depth=12,
                                                  n_heads=12,
                                                  qkv_bias=True,
                                                  mlp_ratio=4)

model_custom.eval()

In [None]:
for (n_o, p_o), (n_c, p_c) in zip(model_official.named_parameters(),
                                  model_custom.named_parameters()):
    assert p_o.numel() == p_c.numel()
    print(f"{n_o}|{n_c}")

    p_c.data = p_o.data


In [None]:
in_p = torch.rand(1, 3, 384, 384)
res_c = model_custom(in_p)
res_o = model_official(in_p)

assert parameter_count(model_custom) == parameter_count(model_official)
assert_close(res_o, res_c)

In [None]:
torch.save(model_custom, 'model.pth')