In [None]:
!nvidia-smi

Install some dependencies:

In [None]:
!pip install einops

# Masked Autoencoder

Implementation of Masked Autoencoders Are Scalable Vision Learners paper on CIFAR datasets.

**Implementation details:** Due to limited resources, we only test the model on CIFAR-10. We mainly want to reproduce the result that pre-training a ViT with MAE achieves a better result than training it directly with supervised labels. This would be evidence that self-supervised learning is more data-efficient than supervised learning.

We mainly follow the implementation details in the paper. However, due to the differences between CIFAR-10 and ImageNet, we make some modifications:

- We use vit-tiny instead of vit-base.
- Since CIFAR-10 has only 50k training images, we increase the number of pre-training epochs from 400 to 2000, and the warmup epochs from 40 to 200. We noticed that the loss is still decreasing after 2000 epochs.
- We decrease the batch size for training the classifier from 1024 to 128 to mitigate the overfitting.

## Learn to Patchify

### Load Dataset

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

# Constants for the patchify / masking demo cells.
BUFFER_SIZE = 1024  # not referenced in the visible cells
BATCH_SIZE = 256  # mini-batch size of the demo loaders below (re-assigned by later cells)
IMAGE_SIZE = 48  # images are resized to IMAGE_SIZE x IMAGE_SIZE
PATCH_SIZE = 6  # side length of one square patch
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2  # patches per image: (48 // 6) ** 2 = 64
MASK_PROPORTION = 0.75
EPOCHS = 250
DOWNSTREAM_EPOCHS = 250

# Data transformations: resize to the demo resolution, then convert to a tensor.
# The two commented-out entries are optional augmentations that are disabled here.
train_transform = transforms.Compose([
#     transforms.RandomResizedCrop(IMAGE_SIZE),
#     transforms.RandomHorizontalFlip(),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

# The test pipeline applies the same resize + tensor conversion as training.
test_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

# CIFAR10 dataset (downloaded into ./data when not already present)
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=test_transform)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

### Extracting Patches

![Transformer Architecture](https://miro.medium.com/v2/resize:fit:1400/format:webp/1*W9ecVUyrjTn2RY9Ufuj88Q.png)

![Patchify](https://miro.medium.com/v2/resize:fit:1160/format:webp/1*f0lDgBvf-nc4IytWmJAEbw.png)

We will focus on the first step of the Vision Transformer, which involves converting images into patches. To achieve this, we will utilize PyTorch’s built-in technique called ‘Unfold.’ We will also explain how Unfold works and demonstrate its application in the context of “patchifying” an image.

The Unfold function in PyTorch enables access to specific parts of a tensor, allowing for further processing. It extracts blocks from a tensor in a sliding manner. It is similar to the max-pooling or average-pooling operation, where a block of a specified size is processed and then the operation slides to the next block. Unfold provides the ability to extract the values from these sliding blocks. Additionally, the Unfold operation flattens the values within each block. Let’s illustrate this with an example.

![Unfold Illustration](https://miro.medium.com/v2/resize:fit:1018/format:webp/1*yUojSXo2Qrs18jhg-lTNTw.png)


For simplicity, let’s consider a tensor with a shape of [1 x 1 x 3 x 3], where the first dimension is the batch size, the second is the number of channels, and the last two are the height and width of the tensor. If we define the unfold operation with a kernel size of (2,2), a block of this size will slide through the tensor, extracting values, flattening them, and moving to the next position. Initially, it will extract values [1, 2, 4, 5], then slide to [2, 3, 5, 6], and continue this process by sliding down. The resultant tensor would have a shape of [1 x 4 x 4], where 1 is the batch size, 4 represents (patch size * number of channels), and the last dimension indicates the number of blocks it has slid through. The exact formula for the output tensor is provided in the documentation (which I will link at the end). The unfold operation in PyTorch is defined as:

```
torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
```

The parameters are similar to what we set in pooling/conv operations in PyTorch.

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

class Patches(nn.Module):
    """Cut images into non-overlapping square patches and stitch them back.

    Args:
        patch_size: Side length of one square patch. It is used both as the
            kernel size and as the stride of the underlying ``Unfold``, so
            neighbouring patches never overlap.
    """

    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size
        self.unfold = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)

    def forward(self, images):
        """Return the patches of ``images``.

        Args:
            images: 4-D tensor laid out as (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, num_patches, channels, patch, patch), with
            patches in the order ``Unfold`` emits its blocks.
        """
        batch, channels, _, _ = images.shape
        side = self.patch_size

        # Unfold flattens every block: (batch, channels * side * side, num_patches)
        blocks = self.unfold(images)

        # Un-flatten each block, then bring the patch axis right behind the batch axis.
        return blocks.view(batch, channels, side, side, -1).movedim(-1, 1)

    def show_patched_image(self, images, patches):
        """Plot one randomly chosen image and its grid of patches.

        Args:
            images: Batch of images, (batch, channels, height, width).
            patches: Matching output of ``forward``.

        Returns:
            The batch index of the image that was displayed.
        """
        chosen = np.random.choice(patches.shape[0])
        print(f"Index selected: {chosen}.")

        # Source image, converted to channels-last for matplotlib.
        plt.figure(figsize=(4, 4))
        plt.imshow(images[chosen].permute(1, 2, 0).cpu().numpy())
        plt.suptitle('Original Image')
        plt.axis("off")
        plt.show()

        # One subplot per patch; the grid is assumed to be square.
        grid = int(np.sqrt(patches.shape[1]))
        plt.figure(figsize=(4, 4))
        plt.suptitle('Patches')
        for slot, tile in enumerate(patches[chosen], start=1):
            plt.subplot(grid, grid, slot)
            plt.imshow(tile.permute(1, 2, 0).cpu().numpy())
            plt.axis("off")
        plt.show()

        return chosen

    def reconstruct_from_patch(self, patch):
        """Stitch the patches of a single image back together.

        Args:
            patch: Tensor of shape (num_patches, channels, patch, patch) whose
                patches are in row-major grid order. The grid is assumed to be
                square (``int(sqrt(num_patches))`` patches per row).

        Returns:
            Tensor of shape (channels, height, width).
        """
        per_row = int(np.sqrt(patch.shape[0]))

        # Each chunk holds one grid row; joining its patches along the last
        # axis (width) yields a (channels, patch, width) strip.
        strips = [
            torch.cat(chunk.unbind(0), dim=2)
            for chunk in torch.split(patch, per_row, dim=0)
        ]

        # Stacking the strips along axis 1 (height) rebuilds the full image.
        return torch.cat(strips, dim=1)
    
# Patch extractor shared by the demo cells below, built with the notebook-wide PATCH_SIZE.
patch_layer = Patches(PATCH_SIZE)

In [None]:
# Grab one batch from the DataLoader; each batch is indexed as (images, labels)
image_batch = next(iter(train_loader))

images = image_batch[0]  # Keep only the images; the labels are not needed here

# Extract patches from the images (no augmentation is applied by train_transform)
patches = patch_layer(images)
print(patches.shape)
# Plot a random image next to its patches; the chosen batch index is returned
random_index = patch_layer.show_patched_image(images, patches)

# Stitch that image's patches back together as a round-trip check
ori_images = patch_layer.reconstruct_from_patch(patches[random_index])

# Show the stitched image; it should look identical to the original shown above
plt.figure(figsize=(4, 4))
original_image = ori_images.permute(1, 2, 0).cpu().numpy()  # Convert to (H, W, C) format
plt.imshow(original_image)
plt.suptitle('Reconstructed Image')
plt.axis("off")
plt.show()

**Note:** In implementation, it is common to use `torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)` to break the input image into patches and project them into an embedding space, i.e. patchify + linear projection (see ViT architecture).

```
self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)
```

- Input Channels (3): This refers to the number of input channels, which in this case is 3, corresponding to the RGB channels of an image.
- Output Channels (emb_dim): This refers to the output channels (or embedding dimension). The patches from the image are projected into this embedding space. Each patch will have a feature vector of length emb_dim.
- Kernel Size (patch_size): The convolution’s kernel size is set to patch_size. This means the convolutional layer will “extract” patches of size patch_size × patch_size from the image.
- Stride (patch_size): The stride is also set to patch_size, which ensures that the patches are non-overlapping. Each patch is treated independently and projected to the embedding dimension.

Both the `Patches` class we defined above and the Conv2d method help in extracting patches from an image. However, Conv2d not only extracts patches but also projects them into a new embedding space, which is commonly used in models where we need to represent each patch with learnable features.

## Masking 

Masking is a technique often used in self-supervised learning which involves hiding (or “masking”) parts of the input data (e.g., patches of an image) and then training the model to predict or reconstruct the missing parts. This helps the model learn better representations without needing explicit labels.

There are different approaches to masking. In this code, we focus on randomly shuffling the patches of an image and masking a portion of them based on a specified ratio. The remaining patches are kept for further processing, while the others are discarded.

Overview of the Approach:

1. Patch Shuffling: The patches of an image are shuffled using randomly generated indexes.
2. Masking (Reducing Patches): A portion of the shuffled patches is removed based on the provided ratio.
3. Index Tracking: We keep track of both the shuffled (forward_indexes) and original (backward_indexes) positions of the patches, so the shuffling can be reversed if needed.

**Hint**
Let’s consider a simple example with 4 patches:

1. Original Order: [patch_0, patch_1, patch_2, patch_3]
2. Forward Shuffle (forward_indexes):
 - Suppose forward_indexes = [2, 0, 3, 1].
 After applying forward_indexes, the new shuffled order of the patches would be: [patch_2, patch_0, patch_3, patch_1]
3. Backward Indexes (backward_indexes):
 - To undo this shuffle, we need backward_indexes. In this case, backward_indexes = [1, 3, 0, 2].
 - Applying backward_indexes to the shuffled patches would restore the original order: [patch_0, patch_1, patch_2, patch_3]

In [None]:
def random_indexes(size : int):
    """Draw a random permutation of ``range(size)`` and its inverse.

    Args:
        size: Number of elements to permute.

    Returns:
        A ``(forward_indexes, backward_indexes)`` pair of NumPy arrays.
        Indexing a sequence with ``forward_indexes`` shuffles it; indexing the
        shuffled sequence with ``backward_indexes`` restores the original order.
    """
    # For an integer argument this shuffles arange(size) using NumPy's global RNG.
    forward_indexes = np.random.permutation(size)
    # The argsort of a permutation is its inverse permutation.
    backward_indexes = np.argsort(forward_indexes)
    return forward_indexes, backward_indexes

In [None]:
# Check: shuffling with forward_indexes and then indexing with backward_indexes
# must give back the original order.

original_order = ['patch_0', 'patch_1', 'patch_2', 'patch_3']

# Step 1: Generate forward and backward indexes
forward_indexes, backward_indexes = random_indexes(len(original_order))

# Step 2: Apply forward_indexes to shuffle the original order
shuffled_order = [original_order[i] for i in forward_indexes]

# Step 3: Apply backward_indexes to restore the original order
restored_order = [shuffled_order[i] for i in backward_indexes]

# Print the results
print("Original Order:", original_order)
print("Forward Shuffle (forward_indexes):", forward_indexes)
print("Shuffled Order:", shuffled_order)
print("Backward Indexes (backward_indexes):", backward_indexes)
print("Restored Order:", restored_order)

# Check if the restored order matches the original
assert restored_order == original_order, "Restored order does not match the original!"
print("Check Passed: Restored order matches the original.")

In [None]:
# Reorders `sequences` along its first (patch) dimension, independently for every
# batch element, according to `indexes`.
def take_indexes(sequences, indexes):
    """Gather whole feature vectors from ``sequences`` along dimension 0.

    Uses only native tensor operations, so it does not depend on an einops
    import made elsewhere in the notebook.

    Args:
        sequences: Tensor of shape (T, B, C).
        indexes: Integer (int64) tensor of shape (T', B); ``indexes[t, b]`` is the
            row of ``sequences[:, b]`` that ends up at position ``t``.

    Returns:
        Tensor of shape (T', B, C) with
        ``out[t, b, :] == sequences[indexes[t, b], b, :]``.
    """
    # gather needs one index per output element, so broadcast every (t, b)
    # index across the channel axis (a view, no copy).
    gather_index = indexes.unsqueeze(-1).expand(-1, -1, sequences.shape[-1])
    return torch.gather(sequences, 0, gather_index)

Example use of `take_indexes` function:

In [None]:
from einops import repeat, rearrange
from einops.layers.torch import Rearrange

# Original tensor of patches: [T, B, C] where T=4, B=1, C=3.
# Every patch is filled with its own index so the reordering is easy to read off.
sequences = torch.tensor([
    [[0, 0, 0]],  # patch_0
    [[1, 1, 1]],  # patch_1
    [[2, 2, 2]],  # patch_2
    [[3, 3, 3]],  # patch_3
])  # Shape: [T=4, B=1, C=3]

# Forward indexes to shuffle the patches
forward_indexes = torch.tensor([2, 0, 3, 1])  # Shape: [T=4]

# Same helper as in the previous cell, repeated so this cell can be read on its own:
# it reorders `sequences` along the first dimension (patches) based on `indexes`.
def take_indexes(sequences, indexes):
    # Repeat the (t, b) indexes along a new channel axis so gather moves whole feature vectors
    return torch.gather(sequences, 0, repeat(indexes, 't b -> t b c', c=sequences.shape[-1]))

# Reorder patches using take_indexes.
# The 1-D index vector is reshaped to [T, B=1] because take_indexes expects 2-D indexes.
shuffled_patches = take_indexes(sequences, forward_indexes.view(-1, 1))

# Print original patches and shuffled patches
print("Original Patches:\n", sequences)
print("\nShuffled Patches:\n", shuffled_patches)

Explanation:

- repeat(indexes, 't b -> t b c', c=sequences.shape[-1]): This line reshapes and repeats the indexes tensor to match the size of the sequences tensor for the batch and channel dimensions (B and C).
- torch.gather: The gather operation uses the reshaped indexes to reorder the patches in the sequences tensor.

In [None]:
class PatchShuffle(torch.nn.Module):
    """Randomly shuffle patches per sample and keep only a leading fraction.

    Args:
        ratio: Fraction of patches to discard (the masking ratio).
    """

    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        """Shuffle and truncate ``patches``.

        Args:
            patches: Tensor of shape (T, B, C), with T patches per sample.

        Returns:
            A tuple ``(kept, forward_indexes, backward_indexes)``: ``kept`` holds
            the first ``int(T * (1 - ratio))`` shuffled patches; both index
            tensors are int64 of shape (T, B) on the device of ``patches``.
        """
        T, B, C = patches.shape
        remain_T = int(T * (1 - self.ratio))

        # One independent (permutation, inverse) pair per sample in the batch.
        per_sample = [random_indexes(T) for _ in range(B)]

        def as_index_tensor(slot):
            # Stack the per-sample vectors as columns -> (T, B).
            columns = np.stack([pair[slot] for pair in per_sample], axis=-1)
            return torch.as_tensor(columns, dtype=torch.long).to(patches.device)

        forward_indexes = as_index_tensor(0)
        backward_indexes = as_index_tensor(1)

        # After shuffling, the first remain_T rows are the visible patches.
        shuffled = take_indexes(patches, forward_indexes)
        return shuffled[:remain_T], forward_indexes, backward_indexes

In [None]:
# Helper: paste the visible patches onto a black canvas, leaving masked grid cells at zero
def reconstruct_with_mask(image, patches, original_shape, mask_indexes, patch_size):
    """Rebuild an image from its visible patches, with masked cells left black.

    Grid cells are visited in row-major order. Every cell whose index is not in
    ``mask_indexes`` receives the next entry of ``patches``, so ``patches`` must
    already be ordered by ascending grid index. Surplus visible cells stay black
    once ``patches`` is exhausted.

    Args:
        image: Reference tensor; only its device is used.
        patches: Tensor of shape (B, num_visible, C, patch, patch).
        original_shape: Shape (B, C, H, W) of the image to rebuild.
        mask_indexes: Tensor holding the grid indexes of the masked cells.
        patch_size: Side length of one square patch.

    Returns:
        Tensor of shape ``original_shape`` on the device of ``image``.
    """
    _, _, _, _, _ = patches.shape  # patches must be 5-D
    canvas = torch.zeros(original_shape).to(image.device)
    hidden = set(mask_indexes.flatten().tolist())

    # Top-left corner of every grid cell, in row-major order.
    corners = [
        (top, left)
        for top in range(0, original_shape[2], patch_size)
        for left in range(0, original_shape[3], patch_size)
    ]
    visible = [corner for cell, corner in enumerate(corners) if cell not in hidden]

    # zip stops at the shorter sequence, so running out of patches is harmless.
    for (top, left), tile in zip(visible, patches.unbind(dim=1)):
        canvas[:, :, top:top + patch_size, left:left + patch_size] = tile

    return canvas

# Pick one random training image and keep a leading batch axis of size 1.
# (For an offline test, a random tensor such as torch.rand(1, 3, 48, 48) works too.)
image = next(iter(train_loader))[0]
idx = np.random.choice(image.shape[0])  # Randomly select an image from the batch
image = image[idx].unsqueeze(0)

# Set patch size and create a patch layer
PATCH_SIZE = 6  # Extract 6x6 patches
patch_layer = Patches(PATCH_SIZE)

# Extract patches from the image
patches = patch_layer(image)  # [B, num_patches, C, patch_size, patch_size]
print(f"Number of patches: {patches.shape[1]}")  # Should be 64 patches (48x48 image with 6x6 patches)

# Create a PatchShuffle layer with a 75% masking ratio (only 25% of the patches are kept)
shuffle_layer = PatchShuffle(ratio=0.75)

# Reshape the patches to match the input expected by PatchShuffle:
# [num_patches, batch_size, channels * patch_size * patch_size]
B, num_patches, C, P_H, P_W = patches.shape
patches_for_shuffling = patches.view(B * num_patches, C * P_H * P_W).unsqueeze(1)  # [num_patches * B, 1, C * patch_size * patch_size]

# Shuffle patches and apply masking
masked_patches, forward_indexes, backward_indexes = shuffle_layer(patches_for_shuffling)

# Reshape the kept patches back to [B, num_patches_remain, C, patch_size, patch_size]
remain_T = masked_patches.shape[0]  # Number of remaining patches after masking
masked_patches = masked_patches.view(remain_T, B, C, P_H, P_W).permute(1, 0, 2, 3, 4)

# Grid positions of the discarded patches (the tail of the shuffled order)
masked_indexes = forward_indexes[remain_T:]

# The kept patches are still in shuffle order, while reconstruct_with_mask fills
# the unmasked grid cells in row-major order. Sort the kept patches by their
# original grid index first, otherwise each visible patch is pasted at the
# position of a different patch.
kept_indexes = forward_indexes[:remain_T, 0]  # single image, so batch column 0
masked_patches = masked_patches[:, torch.argsort(kept_indexes)]

# Reconstruct the original image with black pixels in place of masked patches
reconstructed_image_with_mask = reconstruct_with_mask(image, masked_patches, image.shape, masked_indexes, PATCH_SIZE)

# Display the original and reconstructed image with masked patches
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(image.squeeze(0).permute(1, 2, 0).cpu().numpy())  # Original image
ax[0].set_title('Original')
ax[0].axis("off")

ax[1].imshow(reconstructed_image_with_mask.squeeze(0).permute(1, 2, 0).cpu().numpy())  # Image with masked patches
ax[1].set_title('Masked')
ax[1].axis("off")

plt.show()

## MAE Model

Now, let's build our MAE model.

<!-- ![MAE](https://miro.medium.com/v2/resize:fit:4800/format:webp/1*MbmkubC541LdgIfnseaCQw.png) -->
<div style="text-align: center;">
<img src="https://miro.medium.com/v2/resize:fit:4800/format:webp/1*MbmkubC541LdgIfnseaCQw.png" width="800">
</div>

Here we are using:
- ViT Tiny
- Masking by shuffle approach. See other example approach [here](https://github.com/open-mmlab/mmpretrain/blob/17a886cb5825cd8c26df4e65f7112d404b99fe12/mmpretrain/models/selfsup/mae.py#L108-L150)

In [None]:
import torch
import timm
import numpy as np

from einops import repeat, rearrange
from einops.layers.torch import Rearrange

from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block


class MAE_Encoder(torch.nn.Module):
    """ViT encoder that only processes the patches left visible after masking.

    Args:
        image_size: Side length of the (square) input images.
        patch_size: Side length of one patch; kernel size and stride of the
            patch-embedding convolution.
        emb_dim: Embedding width of every token.
        num_layer: Number of transformer ``Block`` layers.
        num_head: Head count passed to every ``Block``.
        mask_ratio: Fraction of patches dropped by ``PatchShuffle``.
    """

    def __init__(self, image_size=32, patch_size=2, emb_dim=192,
                 num_layer=12, num_head=3, mask_ratio=0.75) -> None:
        super().__init__()

        grid_tokens = (image_size // patch_size) ** 2

        # Learnable class token and one positional embedding per patch
        # (the class token gets no positional embedding in the encoder).
        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros(grid_tokens, 1, emb_dim))
        self.shuffle = PatchShuffle(mask_ratio)

        # Patchify + linear projection in a single strided convolution.
        self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)

        blocks = [Block(emb_dim, num_head) for _ in range(num_layer)]
        self.transformer = torch.nn.Sequential(*blocks)

        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()

    def init_weight(self):
        """Initialise the class token and positional embedding (truncated normal, std 0.02)."""
        for parameter in (self.cls_token, self.pos_embedding):
            trunc_normal_(parameter, std=.02)

    def forward(self, img):
        """Encode the visible patches of ``img``.

        Args:
            img: Image batch; assumed (B, 3, image_size, image_size), as implied
                by the 3-channel convolution and the positional-embedding size.

        Returns:
            ``(features, backward_indexes)``: ``features`` is laid out as
            (1 + visible_patches, B, emb_dim) with the class token first;
            ``backward_indexes`` undoes the shuffle.
        """
        # (B, C, h, w) feature map -> (h*w, B, C) token sequence with positions added.
        tokens = rearrange(self.patchify(img), 'b c h w -> (h w) b c') + self.pos_embedding

        # Random masking; the forward indexes are not needed downstream.
        visible, _, backward_indexes = self.shuffle(tokens)

        # Prepend the class token, as in the ViT architecture.
        cls = self.cls_token.expand(-1, visible.shape[1], -1)
        sequence = rearrange(torch.cat([cls, visible], dim=0), 't b c -> b t c')

        encoded = self.layer_norm(self.transformer(sequence))
        return rearrange(encoded, 'b t c -> t b c'), backward_indexes

class MAE_Decoder(torch.nn.Module):
    """Decoder that rebuilds the full image from the encoder's visible tokens.

    Args:
        image_size: Side length of the (square) images to reconstruct.
        patch_size: Side length of one patch.
        emb_dim: Embedding width of every token.
        num_layer: Number of transformer ``Block`` layers.
        num_head: Head count passed to every ``Block``.
    """

    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()

        # Learnable placeholder inserted at every masked position
        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        # One positional embedding per patch plus one extra row for the leading global token
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))
        
        # Lightweight reconstruction using transformer blocks
        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        # Projects every token to the flattened pixels of one patch (3 channels * patch_size^2)
        self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
        # Other approach similar to reconstruct_from_patch in Patches class using einops
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size)

        self.init_weight()

    def init_weight(self):
        """Initialise the mask token and positional embedding (truncated normal, std 0.02)."""
        trunc_normal_(self.mask_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, features, backward_indexes):
        """Reconstruct images and report which pixels were masked.

        Args:
            features: Token sequence with the sequence axis first; presumably the
                encoder output (global token followed by the visible patches).
            backward_indexes: Index tensor that undoes the patch shuffle, one
                column per sample.

        Returns:
            ``(img, mask)``: the reconstructed image batch and a same-shaped mask
            that is 1 on pixels of masked patches and 0 elsewhere.
        """
        # Number of incoming tokens (global token + visible patches)
        T = features.shape[0]
        # Prepend index 0 for the global token and shift every patch index by one
        backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim=0)
        # Pad with mask tokens until there is one token per index row
        features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim=0)
        # Undo the shuffle so every token sits at its original position
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        # The transformer blocks run batch-first; switch layout around the call
        features = rearrange(features, 't b c -> b t c')
        features = self.transformer(features)
        features = rearrange(features, 'b t c -> t b c')
        features = features[1:] # remove global feature

        patches = self.head(features)
        # Build the mask in shuffled order: with the global token dropped, the
        # first T-1 rows are visible patches and the remaining rows are masked.
        mask = torch.zeros_like(patches)
        mask[T-1:] = 1
        # Bring the mask into original patch order (drop the global row, undo the +1 shift)
        mask = take_indexes(mask, backward_indexes[1:] - 1)
        img = self.patch2img(patches)
        mask = self.patch2img(mask)

        return img, mask

class MAE_ViT(torch.nn.Module):
    """Masked autoencoder: an ``MAE_Encoder`` followed by an ``MAE_Decoder``.

    Args:
        image_size: Side length of the (square) input images.
        patch_size: Side length of one patch.
        emb_dim: Embedding width shared by encoder and decoder.
        encoder_layer: Number of transformer blocks in the encoder.
        encoder_head: Attention heads per encoder block.
        decoder_layer: Number of transformer blocks in the decoder.
        decoder_head: Attention heads per decoder block.
        mask_ratio: Fraction of patches hidden from the encoder.
    """

    def __init__(self, image_size=32, patch_size=2, emb_dim=192,
                 encoder_layer=12, encoder_head=3,
                 decoder_layer=4, decoder_head=3,
                 mask_ratio=0.75) -> None:
        super().__init__()

        self.encoder = MAE_Encoder(
            image_size=image_size,
            patch_size=patch_size,
            emb_dim=emb_dim,
            num_layer=encoder_layer,
            num_head=encoder_head,
            mask_ratio=mask_ratio,
        )
        self.decoder = MAE_Decoder(
            image_size=image_size,
            patch_size=patch_size,
            emb_dim=emb_dim,
            num_layer=decoder_layer,
            num_head=decoder_head,
        )

    def forward(self, img):
        """Mask, encode and reconstruct ``img``.

        Returns:
            ``(predicted_img, mask)`` exactly as produced by the decoder.
        """
        latent, restore_indexes = self.encoder(img)
        predicted_img, mask = self.decoder(latent, restore_indexes)
        return predicted_img, mask


# Quick shape check of the shuffle layer, the encoder and the decoder on random data.
if __name__ == '__main__':
    # 16 patches, batch of 2, 10 features; a 0.75 ratio keeps int(16 * 0.25) = 4 patches
    shuffle = PatchShuffle(0.75)
    a = torch.rand(16, 2, 10)
    b, forward_indexes, backward_indexes = shuffle(a)
    print(b.shape)

    # Encoder and decoder are built with their defaults (32x32 images, 2x2 patches)
    img = torch.rand(2, 3, 32, 32)
    encoder = MAE_Encoder()
    decoder = MAE_Decoder()
    features, backward_indexes = encoder(img)
    # forward_indexes still comes from the PatchShuffle call above; the encoder does not return it
    print(forward_indexes.shape)
    predicted_img, mask = decoder(features, backward_indexes)
    print(predicted_img.shape)
    # Squared error restricted to masked pixels, rescaled by the mask ratio (0.75)
    loss = torch.mean((predicted_img - img) ** 2 * mask / 0.75) # TODO complete the loss
    print(loss)

In [None]:
############################################
# Hyperparameters for MAE pre-training
SEED = 42
BATCH_SIZE = 4096  # effective batch size per optimizer step
MAX_DEVICE_BATCH_SIZE=512  # upper bound on the DataLoader batch; larger batches are reached by gradient accumulation
BASE_LEARNING_RATE=1.5e-4  # scaled by BATCH_SIZE / 256 when the optimizer is built
WEIGHT_DECAY=0.05
MASK_RATIO=0.75  # fraction of patches hidden from the encoder
TOTAL_EPOCH=2000
WARMUP_EPOCH=200  # length of the learning-rate warmup
EARLY_STOP=30  # stop once this epoch index is reached; None disables the early exit
SAVE_MODEL_PATH='vit-t-mae.pt'
############################################

In [None]:
import random

def setup_seed(seed=42):
    """Seed every random number generator used by the notebook.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global generator and
    Python's ``random`` module, then switches cuDNN to deterministic mode.

    Args:
        seed: Seed value handed to each generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

In [None]:
# Load TensorBoard extension
%load_ext tensorboard

# Start TensorBoard
%tensorboard --logdir logs/cifar10/mae-pretrain

In [None]:
import os
import argparse
import math
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToTensor, Compose, Normalize
from tqdm import tqdm


# MAE pre-training: reconstruct masked CIFAR10 patches, with gradient accumulation.
setup_seed(SEED)

# The effective batch (batch_size) is split into loader-sized micro-batches;
# gradients are accumulated over steps_per_update of them before each optimizer step.
batch_size = BATCH_SIZE
load_batch_size = min(MAX_DEVICE_BATCH_SIZE, batch_size)

assert batch_size % load_batch_size == 0
steps_per_update = batch_size // load_batch_size

# Normalize(0.5, 0.5) maps the [0, 1] output of ToTensor to [-1, 1]
train_dataset = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
val_dataset = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
dataloader = torch.utils.data.DataLoader(train_dataset, load_batch_size, shuffle=True, num_workers=4)
writer = SummaryWriter(os.path.join('logs', 'cifar10', 'mae-pretrain'))
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = MAE_ViT(mask_ratio=MASK_RATIO).to(device)
# Learning rate is scaled linearly with the effective batch size (base rate is per 256 samples)
optim = torch.optim.AdamW(model.parameters(), lr=BASE_LEARNING_RATE * BATCH_SIZE / 256, betas=(0.9, 0.95), weight_decay=WEIGHT_DECAY)
# Per-epoch multiplier: the smaller of a linear warmup ramp and a half-cosine decay
lr_func = lambda epoch: min((epoch + 1) / (WARMUP_EPOCH + 1e-8), 0.5 * (math.cos(epoch / TOTAL_EPOCH * math.pi) + 1))
# NOTE(review): `verbose` is reported as deprecated by recent PyTorch releases — confirm against the installed version
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_func, verbose=True)

step_count = 0
optim.zero_grad()
for e in range(TOTAL_EPOCH):
    model.train()
    losses = []
    for img, label in tqdm(iter(dataloader)):
        step_count += 1
        img = img.to(device)
        predicted_img, mask = model(img)
        # Squared error on masked pixels only; dividing by MASK_RATIO compensates for
        # the mean being taken over all pixels
        loss = torch.mean((predicted_img - img) ** 2 * mask) / MASK_RATIO # TODO: Complete the image loss
        loss.backward()
        # Step only after steps_per_update micro-batches; step_count is not reset between epochs
        if step_count % steps_per_update == 0:
            optim.step()
            optim.zero_grad()
        losses.append(loss.item())
    # The scheduler advances once per epoch, not per optimizer step
    lr_scheduler.step()
    avg_loss = sum(losses) / len(losses)
    writer.add_scalar('mae_loss', avg_loss, global_step=e)
    print(f'In epoch {e}, average traning loss is {avg_loss}.')

    ''' visualize the first 16 predicted images on val dataset'''
    model.eval()
    with torch.no_grad():
        val_img = torch.stack([val_dataset[i][0] for i in range(16)])
        val_img = val_img.to(device)
        predicted_val_img, mask = model(val_img)
        # Keep the original pixels where nothing was masked; use the prediction elsewhere
        predicted_val_img = predicted_val_img * mask + val_img * (1 - mask)
        # Three views per image: masked input, reconstruction, original
        img = torch.cat([val_img * (1 - mask), predicted_val_img, val_img], dim=0)
        img = rearrange(img, '(v h1 w1) c h w -> c (h1 h) (w1 v w)', w1=2, v=3)
        # Undo the [-1, 1] normalisation before logging
        writer.add_image('mae_image', (img + 1) / 2, global_step=e)

    ''' save model '''
    # The whole module object is pickled (not just a state_dict), overwriting the file every epoch
    torch.save(model, SAVE_MODEL_PATH)
    if EARLY_STOP and e == EARLY_STOP:
        print("Early stopping (Training takes too long - remove this by setting EARLY_STOP = None) ....")
        break

## Downstream Task

<div style="text-align: center;">
    <img src="https://i.ibb.co/2FV5YnW/Picture1.png" width="600">
</div>

In [None]:
class ViT_Classifier(torch.nn.Module):
    """Image classifier built on top of an ``MAE_Encoder``.

    The encoder's parameters and sub-modules are reused as they are (shared, not
    copied). No masking is applied, and the MAE decoder is not needed: it only
    served the reconstruction pretext task.

    Args:
        encoder: Encoder providing the class token, positional embedding, patch
            embedding, transformer and final layer norm.
        num_classes: Number of output classes.
    """

    def __init__(self, encoder : MAE_Encoder, num_classes=10) -> None:
        super().__init__()
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm

        # New classification head; its input width is the embedding dimension.
        emb_dim = self.pos_embedding.shape[-1]
        self.head = torch.nn.Linear(emb_dim, num_classes)

    def forward(self, img):
        """Return class logits for ``img``, computed from the class-token feature."""
        # Embed every patch (no shuffling or masking) and add positions.
        tokens = rearrange(self.patchify(img), 'b c h w -> (h w) b c') + self.pos_embedding

        # Prepend the class token and run the transformer batch-first.
        cls = self.cls_token.expand(-1, tokens.shape[1], -1)
        sequence = rearrange(torch.cat([cls, tokens], dim=0), 't b c -> b t c')
        encoded = rearrange(self.layer_norm(self.transformer(sequence)), 'b t c -> t b c')

        # Row 0 is the class token.
        return self.head(encoded[0])

In [None]:
# Let's download the pre-trained ViT MAE, since pre-training it from scratch would take too much time even with ViT-Tiny and CIFAR-10
!wget https://github.com/IcarusWizard/MAE/releases/download/cifar10/vit-t-mae.pt
    

In [None]:
############################################
# Hyperparameters Downstream Task
SEED = 42
BATCH_SIZE = 128  # effective batch size per optimizer step
MAX_DEVICE_BATCH_SIZE=256  # upper bound on the DataLoader batch size
BASE_LEARNING_RATE=1e-3  # scaled by BATCH_SIZE / 256 when the optimizer is built
WEIGHT_DECAY=0.05
TOTAL_EPOCH=100
WARMUP_EPOCH=5
# EARLY_STOP=5
PRETRAINED_MODEL_PATH='vit-t-mae.pt' # Set to None to train the classifier from scratch (traditional supervised approach)
# Must differ from PRETRAINED_MODEL_PATH: the fine-tuned classifier would otherwise
# overwrite the MAE checkpoint, and the next run would load an object without `.encoder`.
OUTPUT_MODEL_PATH='vit-t-classifier.pt'
############################################

In [None]:
# # Load TensorBoard extension
# %load_ext tensorboard

# # Start TensorBoard
# %tensorboard --logdir logs/cifar10/pretrain-cls

In [None]:
# Downstream task: train a CIFAR10 classifier on top of the (optionally pre-trained) MAE encoder.
setup_seed(SEED)

# Micro-batches of load_batch_size are accumulated until batch_size samples were seen.
batch_size = BATCH_SIZE
load_batch_size = min(MAX_DEVICE_BATCH_SIZE, batch_size)

assert batch_size % load_batch_size == 0
steps_per_update = batch_size // load_batch_size

train_dataset = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
val_dataset = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
train_dataloader = torch.utils.data.DataLoader(train_dataset, load_batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, load_batch_size, shuffle=False, num_workers=4)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# A missing path (None or an empty string) means "train from scratch".
if PRETRAINED_MODEL_PATH:
    # The checkpoint is a whole pickled module, so the weights-only loader cannot
    # read it and weights_only=False has to be requested explicitly.
    # SECURITY: unpickling executes code from the file - only load checkpoints you trust.
    model = torch.load(PRETRAINED_MODEL_PATH, map_location='cpu', weights_only=False)
    writer = SummaryWriter(os.path.join('logs', 'cifar10', 'pretrain-cls'))
else:
    model = MAE_ViT()
    writer = SummaryWriter(os.path.join('logs', 'cifar10', 'scratch-cls'))
# Only the encoder is reused; the MAE decoder is dropped here.
model = ViT_Classifier(model.encoder, num_classes=10).to(device)

loss_fn = torch.nn.CrossEntropyLoss()
acc_fn = lambda logit, label: torch.mean((logit.argmax(dim=-1) == label).float())

# Learning rate is scaled linearly with the effective batch size (base rate is per 256 samples)
optim = torch.optim.AdamW(model.parameters(), lr=BASE_LEARNING_RATE * BATCH_SIZE / 256, betas=(0.9, 0.999), weight_decay=WEIGHT_DECAY)
# Per-epoch multiplier: the smaller of a linear warmup ramp and a half-cosine decay
lr_func = lambda epoch: min((epoch + 1) / (WARMUP_EPOCH + 1e-8), 0.5 * (math.cos(epoch / TOTAL_EPOCH * math.pi) + 1))
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_func, verbose=True)

best_val_acc = 0
step_count = 0
optim.zero_grad()
for e in range(TOTAL_EPOCH):
    model.train()
    losses = []
    acces = []
    for img, label in tqdm(iter(train_dataloader)):
        step_count += 1
        img = img.to(device)
        label = label.to(device)
        logits = model(img)
        loss = loss_fn(logits, label)
        acc = acc_fn(logits, label)
        loss.backward()
        # Step only once enough micro-batches have been accumulated
        if step_count % steps_per_update == 0:
            optim.step()
            optim.zero_grad()
        losses.append(loss.item())
        acces.append(acc.item())
    # The scheduler advances once per epoch
    lr_scheduler.step()
    avg_train_loss = sum(losses) / len(losses)
    avg_train_acc = sum(acces) / len(acces)
    print(f'In epoch {e}, average training loss is {avg_train_loss}, average training acc is {avg_train_acc}.')

    model.eval()
    with torch.no_grad():
        losses = []
        acces = []
        for img, label in tqdm(iter(val_dataloader)):
            img = img.to(device)
            label = label.to(device)
            logits = model(img)
            loss = loss_fn(logits, label)
            acc = acc_fn(logits, label)
            losses.append(loss.item())
            acces.append(acc.item())
        # Averages are taken over batches (a smaller last batch counts as much as a full one)
        avg_val_loss = sum(losses) / len(losses)
        avg_val_acc = sum(acces) / len(acces)
        print(f'In epoch {e}, average validation loss is {avg_val_loss}, average validation acc is {avg_val_acc}.')  

    # Keep only the checkpoint with the best validation accuracy seen so far
    if avg_val_acc > best_val_acc:
        best_val_acc = avg_val_acc
        print(f'saving best model with acc {best_val_acc} at {e} epoch!')       
        torch.save(model, OUTPUT_MODEL_PATH)

    writer.add_scalars('cls/loss', {'train' : avg_train_loss, 'val' : avg_val_loss}, global_step=e)
    writer.add_scalars('cls/acc', {'train' : avg_train_acc, 'val' : avg_val_acc}, global_step=e)

# Flush pending events so the last epochs show up in TensorBoard
writer.close()

#### Bonus
- Masked Language Modeling Tutorial: https://www.kaggle.com/code/shreydan/masked-language-modeling-from-scratch