In [40]:
import numpy as np

np.set_printoptions(
    edgeitems=30, linewidth=100000,
    formatter=dict(float=lambda x: "%.3g" % x)
)
import torch
from mae.models_mae import MaskedAutoencoderViT

In [41]:
mae = MaskedAutoencoderViT(
    img_size=224, patch_size=14, in_chans=3,
    embed_dim=1024, depth=24, num_heads=16,
    decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
    mlp_ratio=4.
)

In [7]:
!wget https: // dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_huge.pth

--2023-08-13 15:30:08--  https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large_ganloss.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 18.244.140.119, 18.244.140.105, 18.244.140.2, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|18.244.140.119|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1318315181 (1.2G) [binary/octet-stream]
Saving to: 'mae_visualize_vit_large_ganloss.pth'

     0K .......... .......... .......... .......... ..........  0%  193K 1h51m
    50K .......... .......... .......... .......... ..........  0% 17.3M 56m19s
   100K .......... .......... .......... .......... ..........  0%  440K 53m48s
   150K .......... .......... .......... .......... ..........  0% 7.13M 41m5s
   200K .......... .......... .......... .......... ..........  0% 7.05M 33m28s
   250K .......... .......... .......... .......... ..........  0% 10.7M 28m13s
   300K .......... .......... .......... .......... ..........  

In [44]:
N = 7
L = 14 * 14
D = 512
mask_ratio = 0.75
masks_per_img = 4

In [45]:
len_keep = int(L * (1 - mask_ratio))
select = torch.diagflat(torch.ones(masks_per_img)).to(torch.device('cuda'))
select = select.repeat(1, L // masks_per_img)
select, select.shape

(tensor([[1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
          0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.,
          1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
          0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.,
          1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
          0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.,
          1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
          0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.,
          1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
          0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.,
          1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1.,
          0., 0., 0., 1

In [46]:
ids_select = torch.argsort(select, dim=1, stable=True)
ids_select, ids_select.shape

(tensor([[  1,   2,   3,   5,   6,   7,   9,  10,  11,  13,  14,  15,  17,  18,
           19,  21,  22,  23,  25,  26,  27,  29,  30,  31,  33,  34,  35,  37,
           38,  39,  41,  42,  43,  45,  46,  47,  49,  50,  51,  53,  54,  55,
           57,  58,  59,  61,  62,  63,  65,  66,  67,  69,  70,  71,  73,  74,
           75,  77,  78,  79,  81,  82,  83,  85,  86,  87,  89,  90,  91,  93,
           94,  95,  97,  98,  99, 101, 102, 103, 105, 106, 107, 109, 110, 111,
          113, 114, 115, 117, 118, 119, 121, 122, 123, 125, 126, 127, 129, 130,
          131, 133, 134, 135, 137, 138, 139, 141, 142, 143, 145, 146, 147, 149,
          150, 151, 153, 154, 155, 157, 158, 159, 161, 162, 163, 165, 166, 167,
          169, 170, 171, 173, 174, 175, 177, 178, 179, 181, 182, 183, 185, 186,
          187, 189, 190, 191, 193, 194, 195,   0,   4,   8,  12,  16,  20,  24,
           28,  32,  36,  40,  44,  48,  52,  56,  60,  64,  68,  72,  76,  80,
           84,  88,  92,  96, 100, 104, 

In [47]:
ids_restore = torch.argsort(ids_select, dim=1, stable=True)
ids_restore, ids_restore.shape

(tensor([[147,   0,   1,   2, 148,   3,   4,   5, 149,   6,   7,   8, 150,   9,
           10,  11, 151,  12,  13,  14, 152,  15,  16,  17, 153,  18,  19,  20,
          154,  21,  22,  23, 155,  24,  25,  26, 156,  27,  28,  29, 157,  30,
           31,  32, 158,  33,  34,  35, 159,  36,  37,  38, 160,  39,  40,  41,
          161,  42,  43,  44, 162,  45,  46,  47, 163,  48,  49,  50, 164,  51,
           52,  53, 165,  54,  55,  56, 166,  57,  58,  59, 167,  60,  61,  62,
          168,  63,  64,  65, 169,  66,  67,  68, 170,  69,  70,  71, 171,  72,
           73,  74, 172,  75,  76,  77, 173,  78,  79,  80, 174,  81,  82,  83,
          175,  84,  85,  86, 176,  87,  88,  89, 177,  90,  91,  92, 178,  93,
           94,  95, 179,  96,  97,  98, 180,  99, 100, 101, 181, 102, 103, 104,
          182, 105, 106, 107, 183, 108, 109, 110, 184, 111, 112, 113, 185, 114,
          115, 116, 186, 117, 118, 119, 187, 120, 121, 122, 188, 123, 124, 125,
          189, 126, 127, 128, 190, 129, 

In [48]:
ids_keep = ids_select[:, :len_keep]
ids_keep, ids_keep.shape

(tensor([[ 1,  2,  3,  5,  6,  7,  9, 10, 11, 13, 14, 15, 17, 18, 19, 21, 22, 23,
          25, 26, 27, 29, 30, 31, 33, 34, 35, 37, 38, 39, 41, 42, 43, 45, 46, 47,
          49, 50, 51, 53, 54, 55, 57, 58, 59, 61, 62, 63, 65],
         [ 0,  2,  3,  4,  6,  7,  8, 10, 11, 12, 14, 15, 16, 18, 19, 20, 22, 23,
          24, 26, 27, 28, 30, 31, 32, 34, 35, 36, 38, 39, 40, 42, 43, 44, 46, 47,
          48, 50, 51, 52, 54, 55, 56, 58, 59, 60, 62, 63, 64],
         [ 0,  1,  3,  4,  5,  7,  8,  9, 11, 12, 13, 15, 16, 17, 19, 20, 21, 23,
          24, 25, 27, 28, 29, 31, 32, 33, 35, 36, 37, 39, 40, 41, 43, 44, 45, 47,
          48, 49, 51, 52, 53, 55, 56, 57, 59, 60, 61, 63, 64],
         [ 0,  1,  2,  4,  5,  6,  8,  9, 10, 12, 13, 14, 16, 17, 18, 20, 21, 22,
          24, 25, 26, 28, 29, 30, 32, 33, 34, 36, 37, 38, 40, 41, 42, 44, 45, 46,
          48, 49, 50, 52, 53, 54, 56, 57, 58, 60, 61, 62, 64]], device='cuda:0'),
 torch.Size([4, 49]))

In [49]:
x = torch.rand(N, L, D, device=torch.device('cuda'))
x, x.shape

(tensor([[[0.5528, 0.0841, 0.7440,  ..., 0.4665, 0.9507, 0.4640],
          [0.3462, 0.6142, 0.6333,  ..., 0.5306, 0.8530, 0.9411],
          [0.5512, 0.1901, 0.4879,  ..., 0.5767, 0.2706, 0.1304],
          ...,
          [0.6792, 0.5019, 0.5737,  ..., 0.9939, 0.6120, 0.5565],
          [0.8213, 0.5664, 0.1296,  ..., 0.5855, 0.5417, 0.4687],
          [0.0195, 0.3168, 0.7529,  ..., 0.7753, 0.1344, 0.5323]],
 
         [[0.2201, 0.9734, 0.4060,  ..., 0.0545, 0.6258, 0.2451],
          [0.8913, 0.0065, 0.7115,  ..., 0.9488, 0.9827, 0.3569],
          [0.8655, 0.0905, 0.0777,  ..., 0.2366, 0.7689, 0.5585],
          ...,
          [0.9982, 0.0493, 0.7962,  ..., 0.9969, 0.7665, 0.9135],
          [0.8767, 0.2798, 0.0118,  ..., 0.6266, 0.8170, 0.0794],
          [0.3162, 0.7269, 0.8021,  ..., 0.6477, 0.3109, 0.6127]],
 
         [[0.4555, 0.9831, 0.2805,  ..., 0.2821, 0.0814, 0.3664],
          [0.8229, 0.9797, 0.6386,  ..., 0.9889, 0.1556, 0.2799],
          [0.3698, 0.6061, 0.8016,  ...,

In [50]:
i = 0
ids_keep = ids_keep.expand(x.size(0), -1, -1)
ids_restore = ids_restore.expand(x.size(0), -1, -1)
ids_keep.shape, ids_restore.shape

(torch.Size([7, 4, 49]), torch.Size([7, 4, 196]))

In [51]:
x_masked = torch.gather(x, dim=1, index=ids_keep[:, i, :].unsqueeze(-1).repeat(1, 1, D))
x_masked, x_masked.shape

(tensor([[[0.3462, 0.6142, 0.6333,  ..., 0.5306, 0.8530, 0.9411],
          [0.5512, 0.1901, 0.4879,  ..., 0.5767, 0.2706, 0.1304],
          [0.8308, 0.1655, 0.1475,  ..., 0.4480, 0.3579, 0.4627],
          ...,
          [0.1199, 0.0591, 0.8933,  ..., 0.4813, 0.6848, 0.4450],
          [0.0773, 0.7443, 0.2949,  ..., 0.2930, 0.8680, 0.5783],
          [0.6609, 0.9886, 0.5795,  ..., 0.4770, 0.7730, 0.4449]],
 
         [[0.8913, 0.0065, 0.7115,  ..., 0.9488, 0.9827, 0.3569],
          [0.8655, 0.0905, 0.0777,  ..., 0.2366, 0.7689, 0.5585],
          [0.1706, 0.7719, 0.7118,  ..., 0.6757, 0.5710, 0.2129],
          ...,
          [0.9224, 0.9363, 0.9850,  ..., 0.0188, 0.9033, 0.5919],
          [0.4104, 0.8330, 0.3561,  ..., 0.2037, 0.5887, 0.2616],
          [0.0258, 0.8985, 0.1967,  ..., 0.7475, 0.0834, 0.3417]],
 
         [[0.8229, 0.9797, 0.6386,  ..., 0.9889, 0.1556, 0.2799],
          [0.3698, 0.6061, 0.8016,  ..., 0.1815, 0.1700, 0.6211],
          [0.1352, 0.8032, 0.3436,  ...,

In [163]:
# class attribute
# Not needed, equal to `select`
mask = torch.ones((masks_per_img, L), device=x.device)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore[0, :, :])
mask, mask.shape

(tensor([[1., 0., 0.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 1., 0., 0.],
         [0., 0., 1.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.]], device='cuda:0'),
 torch.Size([4, 256]))

In [164]:
mask.expand(x.size(0), -1, -1)[:, i, :]

tensor([[1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
         0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.,
         1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
         0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.,
         1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
         0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.,
         1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
         0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.,
         1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
         0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.,
         1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
         0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.,
         1., 0., 0., 0., 1.,

---

In [39]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
N = 7
L = 16 * 16
D = 512
mask_ratio = 0.25
masks_per_img = int(1 / mask_ratio)
x = torch.rand(N, 3, 224, 224, device=torch.device('cuda'))

In [36]:
from models_mae import MaskedAutoencoderViT
from pprint import pprint

admae = MaskedAutoencoderViT(
    mask_ratio=0.25,
    img_size=224, patch_size=14, in_chans=3,
    embed_dim=1024, depth=24, num_heads=16,
    decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
    mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, lora_rank=8,
)
admae = admae.to(torch.device('cuda'))

In [38]:
result = admae.inference(x)
pprint(result)
print(result["images"].shape, result["preds"].shape, result["loss_map"].shape)

torch.Size([7, 256, 588])
torch.Size([7, 256, 588])
torch.Size([7, 256, 588])
torch.Size([7, 256, 588])
masks.shape = torch.Size([7, 4, 256])
{'images': tensor([[[[9.2759e-01, 9.4558e-01, 5.4245e-01,  ..., 3.5873e-01,
           2.6744e-01, 5.8835e-01],
          [6.4206e-01, 9.8608e-01, 7.5425e-01,  ..., 4.1731e-01,
           5.3256e-01, 5.1589e-01],
          [4.3933e-01, 4.4575e-01, 1.5426e-01,  ..., 7.4423e-01,
           8.5233e-01, 4.7226e-01],
          ...,
          [3.9860e-01, 7.1466e-03, 7.7495e-01,  ..., 8.3167e-01,
           3.1527e-01, 7.4700e-01],
          [5.3397e-01, 5.6741e-01, 9.0871e-01,  ..., 6.1748e-01,
           8.2876e-01, 2.7384e-01],
          [1.4623e-01, 3.1745e-01, 4.2228e-01,  ..., 3.8052e-01,
           5.3874e-01, 3.8587e-01]],

         [[5.4689e-01, 9.9998e-01, 3.4694e-01,  ..., 7.3897e-01,
           3.9865e-02, 5.9242e-01],
          [7.2696e-01, 3.1226e-01, 5.2659e-01,  ..., 3.7195e-01,
           1.2050e-01, 3.5268e-01],
          [3.3379e-01,

In [None]:
# --- forward_encoder loop ---
xi = x.clone()
xi = admae.patch_embed(xi)
xi = xi + admae.pos_embed[:, 1:, :]
# --- alternate_masking ---
admae.ids_keep

In [17]:
from mae.models_mae import MaskedAutoencoderViT

mae = MaskedAutoencoderViT(
    img_size=224, patch_size=14, in_chans=3,
    embed_dim=1024, depth=24, num_heads=16,
    decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
    mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False
).to('cuda')
mae(x)

torch.Size([7, 256, 588])
torch.Size([7, 256, 588])
torch.Size([7, 256, 588])
torch.Size([7, 256, 588])
masks.shape = torch.Size([7, 4, 256])


(tensor(1.2954, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor([[[[ 0.6853, -0.5156,  1.8680,  ..., -0.9848,  0.5311,  1.1564],
           [ 0.3308, -1.1082,  1.8068,  ..., -0.6639,  0.4419,  1.3918],
           [ 0.5207, -1.0546,  1.6258,  ..., -0.7096,  0.3619,  1.3645],
           ...,
           [ 0.3047, -0.2430,  1.8313,  ..., -1.1082,  0.7122,  1.2494],
           [ 0.4479, -0.1585,  1.8186,  ..., -1.0837,  0.7866,  1.2224],
           [ 0.6047, -0.0855,  1.8078,  ..., -1.1415,  0.9036,  1.1645]],
 
          [[ 0.4856, -1.2099,  1.8013,  ..., -0.7292,  0.5318,  1.4291],
           [ 0.8243, -0.3707,  1.9184,  ..., -0.9004,  0.5904,  1.1985],
           [ 0.5634, -1.1012,  1.6228,  ..., -0.6952,  0.4267,  1.3004],
           ...,
           [ 0.3275, -0.2499,  1.8127,  ..., -1.0983,  0.7351,  1.2450],
           [ 0.4699, -0.1654,  1.7997,  ..., -1.0731,  0.8112,  1.2172],
           [ 0.6261, -0.0930,  1.7885,  ..., -1.1312,  0.9285,  1.1582]],
 
          [[ 0.4398, -1.270