In [39]:
import torch

from model.configuration_vit_mae import ViTMAEConfig
from model.mae import ViTMAEEmbeddings, ViTMAEPatchEmbeddings

# Seed torch's global RNG before any module is built. This makes the randomly
# initialised weights, the dummy input below and (assuming it draws from the
# global RNG) the random patch masking reproducible under Restart & Run All.
torch.manual_seed(0)

config = ViTMAEConfig()
vit_embeddings = ViTMAEEmbeddings(config)
# Use the embedding module's *own* patch projection for the step-by-step
# walkthrough. A separately constructed ViTMAEPatchEmbeddings(config) gets its
# own, independently initialised Conv2d weights. The manual pipeline would then
# no longer match what `vit_embeddings` itself computes.
patch_embeddings = vit_embeddings.patch_embeddings
config, patch_embeddings, vit_embeddings

(ViTMAEConfig {
   "attention_probs_dropout_prob": 0.0,
   "decoder_hidden_size": 512,
   "decoder_intermediate_size": 2048,
   "decoder_num_attention_heads": 16,
   "decoder_num_hidden_layers": 8,
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.0,
   "hidden_size": 768,
   "image_size": 224,
   "initializer_range": 0.02,
   "intermediate_size": 3072,
   "layer_norm_eps": 1e-12,
   "mask_ratio": 0.75,
   "model_type": "vit_mae",
   "norm_pix_loss": false,
   "num_attention_heads": 12,
   "num_channels": 3,
   "num_hidden_layers": 12,
   "patch_size": 16,
   "qkv_bias": true,
   "transformers_version": "4.22.1"
 },
 ViTMAEPatchEmbeddings(
   (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
 ),
 ViTMAEEmbeddings(
   (patch_embeddings): ViTMAEPatchEmbeddings(
     (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
   )
 ))

In [40]:
import torch

# Dummy batch: 2 random RGB images at 224x224, values uniform in [0, 1).
# The sizes match num_channels / image_size of the default config shown above.
n_images, n_channels, side = 2, 3, 224
pixel_values = torch.rand(n_images, n_channels, side, side)
(pixel_values, pixel_values.shape)

(tensor([[[[0.4951, 0.0721, 0.4896,  ..., 0.2390, 0.9004, 0.4541],
           [0.7855, 0.5213, 0.3566,  ..., 0.8841, 0.0404, 0.3115],
           [0.6544, 0.8796, 0.6899,  ..., 0.8787, 0.1300, 0.7642],
           ...,
           [0.4336, 0.1213, 0.7397,  ..., 0.0684, 0.6291, 0.8053],
           [0.6662, 0.8739, 0.1029,  ..., 0.8111, 0.3116, 0.9347],
           [0.9309, 0.7031, 0.0873,  ..., 0.1985, 0.1508, 0.3664]],
 
          [[0.7349, 0.9654, 0.3532,  ..., 0.8188, 0.3163, 0.7806],
           [0.4397, 0.2318, 0.7431,  ..., 0.9633, 0.8252, 0.9041],
           [0.1053, 0.0518, 0.4995,  ..., 0.3392, 0.3501, 0.3687],
           ...,
           [0.6691, 0.0453, 0.1278,  ..., 0.6316, 0.6591, 0.2662],
           [0.9109, 0.7459, 0.5275,  ..., 0.8578, 0.6595, 0.7037],
           [0.2074, 0.3899, 0.6594,  ..., 0.3186, 0.3859, 0.8327]],
 
          [[0.2771, 0.5120, 0.7551,  ..., 0.4928, 0.5669, 0.4894],
           [0.1395, 0.7947, 0.4307,  ..., 0.7243, 0.5125, 0.1282],
           [0.3015, 0.57

In [41]:
# Cut each image into non-overlapping 16x16 patches (Conv2d kernel = stride = 16)
# and project every patch to a hidden_size vector:
# result is (batch, num_patches, hidden_size).
emb_patch = patch_embeddings(pixel_values)
(emb_patch, emb_patch.shape)

(tensor([[[ 0.1511,  0.3573, -0.0502,  ..., -0.2963,  0.0760, -0.4688],
          [ 0.1600,  0.3279,  0.1730,  ..., -0.1618,  0.2018, -0.2984],
          [ 0.2183,  0.2686,  0.3596,  ..., -0.0807,  0.0930, -0.4436],
          ...,
          [-0.0231, -0.0600,  0.0582,  ..., -0.2219,  0.4858, -0.3078],
          [ 0.1643,  0.2149,  0.3243,  ...,  0.0296,  0.4256, -0.0761],
          [ 0.2710,  0.1624,  0.3388,  ..., -0.1865,  0.3348, -0.3663]],
 
         [[ 0.3880,  0.4705,  0.2075,  ..., -0.2487,  0.3270, -0.4334],
          [ 0.1031,  0.1022,  0.2230,  ..., -0.0991,  0.1923, -0.4863],
          [ 0.1071,  0.3450, -0.0761,  ..., -0.2565,  0.4488, -0.7837],
          ...,
          [ 0.3444,  0.3786, -0.0654,  ..., -0.0384,  0.3723, -0.3631],
          [ 0.1927,  0.5171, -0.1595,  ...,  0.0655,  0.5558, -0.6155],
          [ 0.2298,  0.3001,  0.2935,  ..., -0.0123,  0.1376, -0.2856]]],
        grad_fn=<TransposeBackward0>),
 torch.Size([2, 196, 768]))

In [43]:
# Position table with one row per token. Row 0 goes with [CLS], rows 1.. go
# with the patches (see the slicing in the later cells). Row 0 prints as zeros
# and row 2 begins with sin(1) = 0.8415, so this looks like a fixed sin-cos
# table. The repr shows no requires_grad=True.
pos_table = vit_embeddings.position_embeddings
(pos_table, pos_table.shape)

(Parameter containing:
 tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 0.8415,  0.8153,  0.7886,  ...,  1.0000,  1.0000,  1.0000],
          ...,
          [-1.0000, -0.8724, -0.5387,  ...,  1.0000,  1.0000,  1.0000],
          [-0.5366, -0.9037, -0.9956,  ...,  1.0000,  1.0000,  1.0000],
          [ 0.4202, -0.1744, -0.6858,  ...,  1.0000,  1.0000,  1.0000]]]),
 torch.Size([1, 197, 768]))

In [45]:
# Patch tokens take position rows 1..N; row 0 is reserved for the [CLS] token.
patch_pos_table = vit_embeddings.position_embeddings[:, 1:, :]
emb_pos = emb_patch + patch_pos_table  # (1, N, D) broadcasts over the batch
(emb_pos, emb_pos.shape)

(tensor([[[ 0.1511,  0.3573, -0.0502,  ...,  0.7037,  1.0760,  0.5312],
          [ 1.0015,  1.1432,  0.9616,  ...,  0.8382,  1.2018,  0.7016],
          [ 1.1276,  1.2128,  1.3295,  ...,  0.9193,  1.0930,  0.5564],
          ...,
          [-1.0231, -0.9323, -0.4805,  ...,  0.7781,  1.4858,  0.6922],
          [-0.3722, -0.6888, -0.6713,  ...,  1.0296,  1.4256,  0.9239],
          [ 0.6912, -0.0120, -0.3470,  ...,  0.8135,  1.3348,  0.6337]],
 
         [[ 0.3880,  0.4705,  0.2075,  ...,  0.7513,  1.3270,  0.5666],
          [ 0.9445,  0.9175,  1.0116,  ...,  0.9009,  1.1923,  0.5137],
          [ 1.0164,  1.2892,  0.8938,  ...,  0.7435,  1.4488,  0.2163],
          ...,
          [-0.6556, -0.4938, -0.6042,  ...,  0.9616,  1.3723,  0.6369],
          [-0.3438, -0.3866, -1.1552,  ...,  1.0655,  1.5558,  0.3845],
          [ 0.6499,  0.1257, -0.3923,  ...,  0.9877,  1.1376,  0.7144]]],
        grad_fn=<AddBackward0>),
 torch.Size([2, 196, 768]))

In [49]:
# Randomly keep a subset of patch tokens per sample (49 of 196 here, with
# mask_ratio = 0.75). Alongside the kept tokens, this returns a 0/1 float `mask`
# and `ids_restore`. `ids_restore` is presumably the permutation that undoes
# the shuffle; confirm against model/mae.py.
masking_result = vit_embeddings.random_masking(emb_pos)
embeddings, mask, ids_restore = masking_result
(
    embeddings, embeddings.shape,
    mask, mask.shape,
    ids_restore, ids_restore.shape,
)

(tensor([[[ 0.5278,  0.2685, -0.6473,  ...,  0.8113,  1.0937,  0.7003],
          [-0.5090, -0.2033, -0.3931,  ...,  0.9461,  1.2617,  0.8079],
          [ 0.5182,  0.7603,  0.9714,  ...,  0.9044,  1.1647,  0.5977],
          ...,
          [ 0.7222, -0.2479, -0.4524,  ...,  0.8263,  1.2958,  0.6242],
          [-0.7043, -0.5573, -0.9992,  ...,  0.7991,  0.9773,  0.7021],
          [-0.7664, -0.6398, -0.3093,  ...,  0.9467,  1.2165,  0.4633]],
 
         [[-0.2689, -0.0183, -0.5378,  ...,  0.6198,  1.1503,  0.7082],
          [-0.4604, -0.2877, -0.8647,  ...,  0.8965,  0.9679,  0.4046],
          [ 1.3429,  0.9516,  1.3404,  ...,  1.0460,  1.1793,  0.9541],
          ...,
          [-0.6682, -0.2119, -0.2083,  ...,  0.6230,  1.5158,  0.8301],
          [-0.0754,  0.2375,  0.7344,  ...,  0.9857,  1.5763,  0.4417],
          [ 0.5742,  1.3351,  1.1841,  ...,  1.1348,  1.2267,  0.6618]]],
        grad_fn=<GatherBackward0>),
 torch.Size([2, 49, 768]),
 tensor([[1., 1., 0., 0., 1., 1., 1., 

In [50]:
# The [CLS] token parameter; a later cell broadcasts it to every sample.
cls_param = vit_embeddings.cls_token
(cls_param, cls_param.shape)

(Parameter containing:
 tensor([[[ 1.0774e-02,  7.7184e-03,  5.8833e-03,  8.9845e-03, -3.4352e-02,
            3.4248e-03,  4.7736e-03,  1.3227e-02,  1.8124e-02, -1.5146e-02,
           -2.5704e-02,  1.8151e-02, -1.5896e-02, -3.8172e-02,  1.8153e-02,
           -1.9877e-02, -9.5269e-04, -4.5868e-02, -3.2151e-03,  1.2454e-02,
           -4.6562e-02, -2.1301e-04, -3.2736e-02, -1.6554e-02, -2.0955e-02,
           -4.7700e-03,  4.0389e-02,  2.0755e-02,  3.0682e-02,  1.7040e-02,
            8.8829e-03,  8.8651e-03,  1.8762e-02,  1.7758e-03,  2.7143e-03,
            1.2448e-02,  5.2147e-02, -1.2137e-02,  3.9249e-03, -5.1513e-02,
            1.9513e-02,  8.2119e-03,  7.0421e-03, -3.1023e-05, -1.2054e-02,
            3.2577e-02,  2.4956e-02,  1.1123e-02,  2.2424e-03,  2.4977e-03,
           -9.4353e-03,  8.4229e-03, -2.2655e-02, -1.4874e-02,  1.3114e-02,
            1.1049e-02, -1.7435e-02, -3.5369e-02,  2.3251e-03,  1.9796e-02,
            7.8293e-03,  2.6074e-03, -2.9100e-04, -8.9573e-03,  8

In [51]:
# [CLS] takes position row 0. That row appears to be all zeros here, so the
# token values come out unchanged (compare with the previous output).
cls_pos_row = vit_embeddings.position_embeddings[:, :1, :]
cls_token_pos = vit_embeddings.cls_token + cls_pos_row
(cls_token_pos, cls_token_pos.shape)

(tensor([[[ 1.0774e-02,  7.7184e-03,  5.8833e-03,  8.9845e-03, -3.4352e-02,
            3.4248e-03,  4.7736e-03,  1.3227e-02,  1.8124e-02, -1.5146e-02,
           -2.5704e-02,  1.8151e-02, -1.5896e-02, -3.8172e-02,  1.8153e-02,
           -1.9877e-02, -9.5269e-04, -4.5868e-02, -3.2151e-03,  1.2454e-02,
           -4.6562e-02, -2.1301e-04, -3.2736e-02, -1.6554e-02, -2.0955e-02,
           -4.7700e-03,  4.0389e-02,  2.0755e-02,  3.0682e-02,  1.7040e-02,
            8.8829e-03,  8.8651e-03,  1.8762e-02,  1.7758e-03,  2.7143e-03,
            1.2448e-02,  5.2147e-02, -1.2137e-02,  3.9249e-03, -5.1513e-02,
            1.9513e-02,  8.2119e-03,  7.0421e-03, -3.1023e-05, -1.2054e-02,
            3.2577e-02,  2.4956e-02,  1.1123e-02,  2.2424e-03,  2.4977e-03,
           -9.4353e-03,  8.4229e-03, -2.2655e-02, -1.4874e-02,  1.3114e-02,
            1.1049e-02, -1.7435e-02, -3.5369e-02,  2.3251e-03,  1.9796e-02,
            7.8293e-03,  2.6074e-03, -2.9100e-04, -8.9573e-03,  8.1006e-03,
           -

In [52]:
# Repeat the single [CLS] row for every sample. `expand` returns a broadcast
# view (no copy), shaped (batch, 1, hidden_size).
n_samples = embeddings.shape[0]
cls_tokens = cls_token_pos.expand(n_samples, -1, -1)
(cls_tokens, cls_tokens.shape)

(tensor([[[ 0.0108,  0.0077,  0.0059,  ...,  0.0257,  0.0042, -0.0149]],
 
         [[ 0.0108,  0.0077,  0.0059,  ...,  0.0257,  0.0042, -0.0149]]],
        grad_fn=<ExpandBackward0>),
 torch.Size([2, 1, 768]))

In [53]:
# Prepend [CLS] to the kept patch tokens along the sequence axis:
# (batch, 1, D) + (batch, num_kept, D) -> (batch, 1 + num_kept, D).
# NOTE(review): this rebinds `embeddings`, so re-running the cell prepends a
# second [CLS]. Re-run from the masking cell instead, or bind a new name if no
# later cell relies on `embeddings`.
embeddings = torch.cat([cls_tokens, embeddings], dim=1)
(embeddings, embeddings.shape)

(tensor([[[ 0.0108,  0.0077,  0.0059,  ...,  0.0257,  0.0042, -0.0149],
          [ 0.5278,  0.2685, -0.6473,  ...,  0.8113,  1.0937,  0.7003],
          [-0.5090, -0.2033, -0.3931,  ...,  0.9461,  1.2617,  0.8079],
          ...,
          [ 0.7222, -0.2479, -0.4524,  ...,  0.8263,  1.2958,  0.6242],
          [-0.7043, -0.5573, -0.9992,  ...,  0.7991,  0.9773,  0.7021],
          [-0.7664, -0.6398, -0.3093,  ...,  0.9467,  1.2165,  0.4633]],
 
         [[ 0.0108,  0.0077,  0.0059,  ...,  0.0257,  0.0042, -0.0149],
          [-0.2689, -0.0183, -0.5378,  ...,  0.6198,  1.1503,  0.7082],
          [-0.4604, -0.2877, -0.8647,  ...,  0.8965,  0.9679,  0.4046],
          ...,
          [-0.6682, -0.2119, -0.2083,  ...,  0.6230,  1.5158,  0.8301],
          [-0.0754,  0.2375,  0.7344,  ...,  0.9857,  1.5763,  0.4417],
          [ 0.5742,  1.3351,  1.1841,  ...,  1.1348,  1.2267,  0.6618]]],
        grad_fn=<CatBackward0>),
 torch.Size([2, 50, 768]))