# Line-by-line walkthrough of the PyTorch code for Masked AutoEncoder (MAE)

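From the Bilibili uploader deep_thoughts, series "PyTorch Source-Code Tutorials and Reproductions of Cutting-Edge AI Algorithms" (【PyTorch源码教程与前沿人工智能算法复现讲解】)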

P_43: Line-by-line walkthrough of the PyTorch code for Masked AutoEncoder (MAE):

https://www.bilibili.com/video/BV1JS4y1N7XE/?spm_id_from=333.788&vd_source=18e91d849da09d846f771c89a366ed40

***Paper***

Masked Autoencoders Are Scalable Vision Learners:

https://arxiv.org/pdf/2111.06377.pdf

***Code***

MAE:

https://github.com/facebookresearch/mae

# MAE code

## data preprocess
* image2tensor
  * RGB, 3 channels
  * PIL.Image.open + convert("RGB"), or torchvision.datasets.ImageFolder
  * shape: (C, H, W), dtype: uint8
    * unsigned 8-bit integer, values 0 to 255
    * #000000 (black): 0 in every channel
    * #FFFFFF (white): 255 in every channel
* augment
  * Crop/Resize/Flip
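* convert (uint8 to floating point in [0, 1])
  * torchvision.transforms.ToTensor(): PIL image → float tensor in [0, 1]
  * torchvision.transforms.PILToTensor() keeps uint8; follow it with ConvertImageDtype(torch.float) to get [0, 1]
  * torchvision.transforms.ToPILImage goes the other way (tensor → PIL image)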
* normalize
  * (image - mean) / std per channel, using global (dataset-level) statistics
  * ImageNet-1k
    * mean:[0.485,0.456,0.406]
    * std:[0.229,0.224,0.225]
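The convert and normalize steps are plain arithmetic; in torchvision they correspond to `transforms.ToTensor()` followed by `transforms.Normalize(mean, std)`. Below is a minimal pure-PyTorch sketch of that arithmetic, assuming the image is already a `(C, H, W)` uint8 tensor; the function name `to_float_and_normalize` and the tiny test image are hypothetical.

```python
import torch

# ImageNet-1k per-channel statistics (values from the notes above)
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def to_float_and_normalize(img_uint8: torch.Tensor) -> torch.Tensor:
    """img_uint8: (C, H, W) uint8 tensor with values in [0, 255] (hypothetical helper)."""
    assert img_uint8.dtype == torch.uint8 and img_uint8.shape[0] == 3
    img = img_uint8.float() / 255.0                # convert: [0, 255] -> [0, 1]
    return (img - IMAGENET_MEAN) / IMAGENET_STD    # normalize each channel


# hypothetical 2x2 image: one black (#000000) pixel, the rest white (#FFFFFF)
img = torch.full((3, 2, 2), 255, dtype=torch.uint8)
img[:, 0, 0] = 0
out = to_float_and_normalize(img)
print(out.shape, out.dtype)  # torch.Size([3, 2, 2]) torch.float32
```

Broadcasting the `(3, 1, 1)` statistics over `(3, H, W)` is what makes the normalization per channel.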

## model
* encoder
  * image2patch2embedding
  * position embedding
  * random masking (shuffle)
  * class token
  * Transformer blocks (ViT-Base/ViT-Large/ViT-Huge)
* decoder
  * projection_layer
  * unshuffle
  * position embedding
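  * Transformer blocks (shallow)
  * regression layer (predicts the pixel values of each patch)
  * MSE loss, computed on masked patches only (optionally on per-patch normalized pixels, `norm_pix_loss`)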
* forward functions
  * forward encoder
  * forward decoder
  * forward loss
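The loss bullets (regression target = patch pixels, MSE averaged over masked patches only, optional per-patch normalization) can be sketched as follows. The official repository implements `patchify` and `forward_loss` as methods of the model; these standalone functions, with the patch size passed in as `p`, are a simplified sketch under that assumption rather than the repo code.

```python
import torch


def patchify(imgs: torch.Tensor, p: int) -> torch.Tensor:
    """(N, 3, H, W) -> (N, L, p*p*3), where L = (H/p) * (W/p)."""
    n, c, h_img, w_img = imgs.shape
    assert h_img % p == 0 and w_img % p == 0
    h, w = h_img // p, w_img // p
    x = imgs.reshape(n, c, h, p, w, p)
    x = torch.einsum("nchpwq->nhwpqc", x)  # group pixels patch by patch
    return x.reshape(n, h * w, p * p * c)


def forward_loss(imgs, pred, mask, p, norm_pix_loss=False):
    """pred: (N, L, p*p*3) decoder output; mask: (N, L) with 1 = masked, 0 = visible."""
    target = patchify(imgs, p)
    if norm_pix_loss:
        # normalize each target patch by its own mean and variance
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1e-6) ** 0.5
    loss = ((pred - target) ** 2).mean(dim=-1)  # (N, L): MSE per patch
    return (loss * mask).sum() / mask.sum()     # average over masked patches only


torch.manual_seed(0)
imgs = torch.randn(2, 3, 32, 32)
patches = patchify(imgs, p=16)
print(patches.shape)  # torch.Size([2, 4, 768])
mask = torch.tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0]])
pred = torch.zeros_like(patches)
print(forward_loss(imgs, pred, mask, p=16, norm_pix_loss=True))
```

Multiplying by `mask` means errors on visible patches never reach the gradient, so the decoder is trained only to reconstruct what the encoder did not see.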

## training
* dataset
* data_loader
* model
* optimizer
* load_model
  * model.state_dict()
  * optimizer.state_dict()
  * epoch
* train_one_epoch
* save_model
  * model.state_dict()
  * optimizer.state_dict()
  * epoch/loss
  * config
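The save/load bullets boil down to putting several `state_dict()`s and some metadata in one dictionary. A minimal sketch, assuming a toy `nn.Linear` model and an in-memory buffer in place of a checkpoint file; `save_checkpoint`, `load_checkpoint`, and the dictionary keys are hypothetical names, not the MAE repo's.

```python
import io

import torch
import torch.nn as nn


def save_checkpoint(f, model, optimizer, epoch, loss, config):
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),  # e.g. AdamW moment estimates
        "epoch": epoch,
        "loss": loss,
        "config": config,
    }, f)


def load_checkpoint(f, model, optimizer):
    ckpt = torch.load(f, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epoch"] + 1  # resume from the next epoch


model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# one dummy step so the optimizer has state worth saving
model(torch.randn(8, 4)).pow(2).mean().backward()
optimizer.step()

buf = io.BytesIO()  # stands in for a checkpoint file path
save_checkpoint(buf, model, optimizer, epoch=9, loss=0.42, config={"lr": 1e-3})
buf.seek(0)

new_model = nn.Linear(4, 2)
new_opt = torch.optim.AdamW(new_model.parameters(), lr=1e-3)
start_epoch = load_checkpoint(buf, new_model, new_opt)
print(start_epoch)  # 10
```

Saving the optimizer state alongside the weights matters for resuming: without it, AdamW restarts with zeroed moment estimates.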

## finetuning
* strong augmentation
* build encoder + LayerNorm (on the pooled features) + linear classifier head
* interpolate position embedding
* load pre-trained weights (strict=False)
* update all parameters
* AdamW optimizer
* label smoothing cross-entropy loss
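Interpolating the position embedding is needed when fine-tuning at a different resolution than pre-training: with 16x16 patches, 224x224 inputs give a 14x14 grid while 384x384 gives 24x24. A sketch using bicubic interpolation, assuming a single class token in front of the patch embeddings; the official repo has a helper with a similar purpose (`interpolate_pos_embed`), but this standalone version and its signature are illustrative. After resizing, `model.load_state_dict(state, strict=False)` tolerates the missing classifier-head keys and returns them in `missing_keys`.

```python
import torch
import torch.nn.functional as F


def interpolate_pos_embed(pos_embed: torch.Tensor, new_grid: int,
                          num_extra_tokens: int = 1) -> torch.Tensor:
    """(1, num_extra_tokens + G*G, D) -> (1, num_extra_tokens + new_grid**2, D)."""
    extra = pos_embed[:, :num_extra_tokens]  # e.g. the class-token embedding, kept as-is
    grid = pos_embed[:, num_extra_tokens:]
    dim = grid.shape[-1]
    old_grid = int(grid.shape[1] ** 0.5)
    assert old_grid * old_grid == grid.shape[1], "patch grid must be square"
    grid = grid.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)  # (1, D, G, G)
    grid = F.interpolate(grid, size=(new_grid, new_grid), mode="bicubic",
                         align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat([extra, grid], dim=1)


pos = torch.randn(1, 1 + 14 * 14, 768)  # pre-trained at 224x224, patch size 16
new = interpolate_pos_embed(pos, new_grid=24)  # fine-tune at 384x384
print(new.shape)  # torch.Size([1, 577, 768])
```

The embeddings are treated as a 2D image with D channels, so neighboring patches keep smoothly varying position codes after resizing.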

## linear probing
* weak augmentation
* build encoder + BN (affine=False) + linear classifier head
* interpolate position embedding
* only update the parameters of the linear classifier head (encoder frozen)
* LARS optimizer
* cross-entropy loss
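A minimal linear-probing setup: freeze the encoder, put `BatchNorm1d(affine=False)` in front of a single `nn.Linear`, and give the optimizer only the head's parameters. `DummyEncoder` is a hypothetical stand-in for the pre-trained ViT encoder, and SGD with momentum is swapped in for LARS, which is not part of `torch.optim` (the MAE repository ships its own LARS implementation).

```python
import torch
import torch.nn as nn


class DummyEncoder(nn.Module):
    """Hypothetical stand-in for the pre-trained encoder: (B, 3, 8, 8) -> (B, dim)."""

    def __init__(self, dim=32):
        super().__init__()
        self.proj = nn.Linear(3 * 8 * 8, dim)

    def forward(self, x):
        return self.proj(x.flatten(1))


num_classes, dim = 10, 32
encoder = DummyEncoder(dim)
for p in encoder.parameters():
    p.requires_grad = False  # freeze the encoder

head = nn.Sequential(
    nn.BatchNorm1d(dim, affine=False, eps=1e-6),  # BN without learnable scale/shift
    nn.Linear(dim, num_classes),                  # linear classifier
)
model = nn.Sequential(encoder, head)
optimizer = torch.optim.SGD(head.parameters(), lr=0.1, momentum=0.9)  # stand-in for LARS
criterion = nn.CrossEntropyLoss()

x = torch.randn(16, 3, 8, 8)
y = torch.randint(0, num_classes, (16,))
w_before = encoder.proj.weight.clone()

model.train()
loss = criterion(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(torch.equal(w_before, encoder.proj.weight))  # True: the encoder did not move
```

Because the BN has no affine parameters, the only trainable weights are the linear layer's, which keeps the probe truly linear on top of the frozen features.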

## evaluation
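* with torch.no_grad()
  * efficient: no autograd graph is built, saving memory and compute
* model.eval()
  * correct BN/dropout behavior: BN uses running statistics, dropout is disabled
* top-k accuracy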
  * top_1
  * top_5
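The evaluation bullets combine into a short loop. A sketch assuming the data loader yields `(images, target)` batches; `accuracy` and `evaluate` are hypothetical helper names, and the tiny linear model and random batches exist only for illustration.

```python
import torch
import torch.nn as nn


def accuracy(output, target, topk=(1, 5)):
    """Top-k accuracy (in %) for each k, given logits (B, C) and labels (B,)."""
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)  # (B, maxk)
    correct = pred.eq(target.view(-1, 1))                           # (B, maxk) bool
    return [correct[:, :k].any(dim=1).float().mean().item() * 100 for k in topk]


@torch.no_grad()  # no autograd graph: less memory, faster
def evaluate(model, loader):
    model.eval()  # BN uses running statistics, dropout is disabled
    n, top1, top5 = 0, 0.0, 0.0
    for images, target in loader:
        acc1, acc5 = accuracy(model(images), target, topk=(1, 5))
        b = images.shape[0]
        top1, top5, n = top1 + acc1 * b, top5 + acc5 * b, n + b
    return top1 / n, top5 / n


logits = torch.tensor([[0.1, 0.9, 0.0, 0.0, 0.0, 0.0],
                       [0.8, 0.1, 0.05, 0.03, 0.02, 0.0]])
targets = torch.tensor([1, 3])
print(accuracy(logits, targets))  # [50.0, 100.0]

dummy = nn.Sequential(nn.Linear(4, 10))
loader = [(torch.randn(8, 4), torch.randint(0, 10, (8,))) for _ in range(3)]
print(evaluate(dummy, loader))
```

In the example, the second sample's label (class 3) is not the top prediction but is within the top 5, so top-1 is 50% and top-5 is 100%.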

## Demonstrating shuffle

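```python
import torch

x = torch.rand(5)                         # random scores, one per position
print(x)
idx_shuffle = torch.argsort(x)            # ascending order of x: a random permutation
print(idx_shuffle)
idx_restore = torch.argsort(idx_shuffle)  # inverse permutation: undoes the shuffle
print(idx_restore)
```

Output (random, so it differs from run to run):

```
tensor([0.2304, 0.9777, 0.2777, 0.2373, 0.9126])
tensor([0, 3, 2, 4, 1])
tensor([0, 4, 2, 1, 3])
```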
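This demo is the core of MAE's random masking: sorting uniform noise yields a random permutation (`ids_shuffle`), and `argsort` of that permutation is its inverse (`ids_restore`), which the decoder uses to unshuffle. A minimal sketch in the spirit of the official `random_masking`, simplified to ignore the class token; `unshuffle` and the all-zero `mask_token` are hypothetical names and values for illustration (in a real model the mask token is a learnable parameter).

```python
import torch


def random_masking(x: torch.Tensor, mask_ratio: float):
    """x: (N, L, D) patch embeddings; keep a random (1 - mask_ratio) fraction of patches."""
    n, L, d = x.shape
    len_keep = int(L * (1 - mask_ratio))
    noise = torch.rand(n, L, device=x.device)        # one random score per patch
    ids_shuffle = torch.argsort(noise, dim=1)        # ascending: smallest scores are kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, d))
    # binary mask in the original patch order: 0 = kept, 1 = masked
    mask = torch.ones(n, L, device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return x_masked, mask, ids_restore


def unshuffle(x_masked, mask_token, ids_restore):
    """Append mask tokens, then put every token back at its original position."""
    n, len_keep, d = x_masked.shape
    L = ids_restore.shape[1]
    x_ = torch.cat([x_masked, mask_token.expand(n, L - len_keep, d)], dim=1)  # shuffled order
    return torch.gather(x_, 1, ids_restore.unsqueeze(-1).expand(-1, -1, d))


torch.manual_seed(0)
x = torch.randn(2, 16, 4)  # 2 images, 16 patches, 4-dim embeddings
x_masked, mask, ids_restore = random_masking(x, mask_ratio=0.75)
print(x_masked.shape)    # torch.Size([2, 4, 4])
print(mask.sum(dim=1))   # tensor([12., 12.])
mask_token = torch.zeros(1, 1, 4)  # hypothetical; learnable in a real model
x_full = unshuffle(x_masked, mask_token, ids_restore)
print(x_full.shape)      # torch.Size([2, 16, 4])
```

The encoder only ever processes the 25% of tokens in `x_masked`, which is where MAE's pre-training speedup comes from; the full-length sequence is rebuilt only for the lightweight decoder.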
