In [None]:
import torch
import torch.nn as nn
import numpy as np
from einops import repeat, rearrange
from einops.layers.torch import Rearrange



from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block

In [None]:
class TransformerBlock(nn.Module):
    '''
    Pre-norm transformer block: an attention sub-layer followed by an MLP sub-layer, each
    wrapped in a residual connection whose branch passes through drop path.

    Data flow :-

    x -> LayerNorm1 -> Attention -> DropPath -> (+ x) -> LayerNorm2 -> Mlp -> DropPath -> (+) -> out

    NOTE(review): ``Attention``, ``DropPath`` and ``Mlp`` are neither defined nor imported in
    the visible part of this notebook; they must be in scope before this class is instantiated.

    Args:
        embed_dim: token embedding width handed to both norm layers, the attention module
            and the MLP.
        heads: number of attention heads forwarded to ``Attention``.
        mlp_ratio: hidden width of the MLP expressed as a multiple of ``embed_dim``.
        drop_ratio: forwarded as ``proj_drop_ratio`` to attention and as ``drop`` to the MLP.
        attn_drop_ratio: forwarded to ``Attention``.
        drop_path_ratio: forwarded to ``DropPath``.
        activation: activation layer class handed to the MLP as ``act_layer``.
        norm_layer: normalization layer class, instantiated with ``embed_dim``.
    '''

    def __init__(self,
                 embed_dim,
                 heads=3,           # default lowered from 8 to 3
                 mlp_ratio=4,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 activation=nn.GELU,
                 norm_layer = nn.LayerNorm):
        super().__init__()

        # Attention sub-layer (sub-modules are created in the same order as they are used).
        self.norm1 = norm_layer(embed_dim)
        self.attn = Attention(
            embed_dim,
            heads=heads,
            attn_drop_ratio=attn_drop_ratio,
            proj_drop_ratio=drop_ratio,
        )

        # A single drop-path module is shared by both residual branches.
        self.drop_path = DropPath(drop_path_ratio)

        # Feed-forward sub-layer.
        self.norm2 = norm_layer(embed_dim)
        hidden_width = int(embed_dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=embed_dim,
            hidden_features=hidden_width,
            act_layer=activation,
            drop=drop_ratio,
        )

    def forward(self, x):
        '''Apply the attention branch and then the MLP branch, each added back onto ``x``.'''
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)





In [None]:
def random_indexes(size : int):
    """Draw a random permutation of ``range(size)`` together with its inverse.

    The permutation is produced with NumPy's global random state, so it is
    controlled by ``np.random.seed``.

    Returns:
        A ``(permutation, inverse)`` pair of 1-D integer arrays such that
        ``permutation[inverse]`` is ``0, 1, ..., size - 1``.
    """
    permutation = np.arange(size)
    np.random.shuffle(permutation)  # in-place shuffle
    # argsort of a permutation yields the permutation that undoes it.
    inverse = np.argsort(permutation)
    return permutation, inverse

def take_indexes(sequences, indexes):
    """Reorder ``sequences`` along its first axis, independently for each batch column.

    ``indexes`` is laid out as ``(t, b)`` (see the einops pattern below) and is
    broadcast over the last axis of ``sequences`` so that whole feature vectors
    are moved together by ``torch.gather`` along dim 0.
    """
    channels = sequences.shape[-1]
    gather_index = repeat(indexes, 't b -> t b c', c=channels)
    return torch.gather(sequences, 0, gather_index)

class PatchShuffle(torch.nn.Module):
    """Randomly permute the patches of every sample and keep only the leading part.

    ``ratio`` is the fraction of patches that is dropped (masked); the first
    ``int(T * (1 - ratio))`` shuffled patches are returned as the visible ones.
    """

    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        """Shuffle and truncate ``patches`` of shape (length, batch, dim).

        Returns:
            ``(visible_patches, forward_indexes, backward_indexes)`` where both
            index tensors are ``torch.long`` with shape (length, batch);
            ``backward_indexes`` undoes the shuffle described by ``forward_indexes``.
        """
        num_patches, batch, _dim = patches.shape
        num_visible = int(num_patches * (1 - self.ratio))

        # One independent permutation (and its inverse) per sample in the batch.
        forward_columns = []
        backward_columns = []
        for _ in range(batch):
            fwd, bwd = random_indexes(num_patches)
            forward_columns.append(fwd)
            backward_columns.append(bwd)

        forward_indexes = torch.as_tensor(
            np.stack(forward_columns, axis=-1), dtype=torch.long
        ).to(patches.device)
        backward_indexes = torch.as_tensor(
            np.stack(backward_columns, axis=-1), dtype=torch.long
        ).to(patches.device)

        shuffled = take_indexes(patches, forward_indexes)
        # The leading rows are the patches that stay visible (not masked).
        visible = shuffled[:num_visible]

        return visible, forward_indexes, backward_indexes

In [None]:
class MAE_Encoder(torch.nn.Module):
    """MAE encoder: patch embedding, random masking, and a ViT over the visible patches.

    Args:
        image_size: side length of the (square) input image.
        patch_size: side length of a (square) patch; also the conv kernel and stride.
        emb_dim: token embedding width.
        num_layer: number of transformer ``Block`` layers.
        num_head: second positional argument passed to each ``Block``.
        mask_ratio: fraction of patches dropped by ``PatchShuffle``.
    """

    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=12,
                 num_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()
        num_patches = (image_size // patch_size) ** 2

        # Learnable class token and one positional embedding per patch (none for the class token).
        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros(num_patches, 1, emb_dim))

        # Shuffles the patches and keeps only the visible fraction.
        self.shuffle = PatchShuffle(mask_ratio)

        # Non-overlapping convolution (kernel == stride == patch_size) embeds each patch.
        self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)

        encoder_blocks = [Block(emb_dim, num_head) for _ in range(num_layer)]
        self.transformer = torch.nn.Sequential(*encoder_blocks)

        # Final layer norm applied to the transformer output.
        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()

    def init_weight(self):
        """Initialise the class token and positional embedding with a truncated normal (std 0.02)."""
        for parameter in (self.cls_token, self.pos_embedding):
            trunc_normal_(parameter, std=.02)

    def forward(self, img):
        """Encode the visible patches of ``img``.

        Returns:
            ``(features, backward_indexes)``: ``features`` is laid out as
            (tokens, batch, dim) with the class token in row 0, and
            ``backward_indexes`` restores the original patch order.
        """
        tokens = rearrange(self.patchify(img), 'b c h w -> (h w) b c')
        tokens = tokens + self.pos_embedding

        visible, _forward_indexes, backward_indexes = self.shuffle(tokens)

        # Prepend the class token, then run the transformer in batch-first layout.
        cls = self.cls_token.expand(-1, visible.shape[1], -1)
        sequence = rearrange(torch.cat([cls, visible], dim=0), 't b c -> b t c')
        encoded = self.layer_norm(self.transformer(sequence))
        features = rearrange(encoded, 'b t c -> t b c')

        return features, backward_indexes

In [None]:
class MAE_Decoder(torch.nn.Module):
    """MAE decoder: re-inserts mask tokens, restores patch order and predicts pixels.

    Args:
        image_size: side length of the (square) image to reconstruct.
        patch_size: side length of a (square) patch.
        emb_dim: token embedding width.
        num_layer: number of transformer ``Block`` layers.
        num_head: second positional argument passed to each ``Block``.
    """

    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()

        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        # One positional embedding per patch plus one for the class token.
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        # Projects every token to the flattened RGB pixels of one patch.
        self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size)

        self.init_weight()

    def init_weight(self):
        """Initialise the mask token and positional embedding with a truncated normal (std 0.02)."""
        trunc_normal_(self.mask_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, features, backward_indexes):
        """Reconstruct the image from the encoder output.

        Args:
            features: encoder output laid out as (tokens, batch, dim); row 0 is the
                class token, the remaining rows are the visible patches in shuffled order.
            backward_indexes: (num_patches, batch) indexes that undo the encoder shuffle.

        Returns:
            ``(img, mask)``: the predicted image and a same-shaped mask that is 1 on
            pixels belonging to masked patches and 0 on pixels of visible patches.
        """
        # T counts the class token as well, so only T - 1 rows are visible patches.
        T = features.shape[0]

        # Shift the patch indexes by one and put index 0 (class token) in front.
        cls_index = torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes)
        backward_indexes = torch.cat([cls_index, backward_indexes + 1], dim=0)

        # Append one mask token for every patch the encoder did not see, then unshuffle.
        num_masked = backward_indexes.shape[0] - features.shape[0]
        features = torch.cat([features, self.mask_token.expand(num_masked, features.shape[1], -1)], dim=0)
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        features = rearrange(features, 't b c -> b t c')
        features = self.transformer(features)
        features = rearrange(features, 'b t c -> t b c')
        features = features[1:] # remove global feature

        patches = self.head(features) # per-patch pixel predictions

        # Build the mask in shuffled order, where the first T - 1 rows are the visible
        # patches. The class-token row was removed above, so the masked patches start
        # at row T - 1 (using T here would leave one masked patch per image unmarked).
        mask = torch.zeros_like(patches)
        mask[T - 1:] = 1
        mask = take_indexes(mask, backward_indexes[1:] - 1)

        img = self.patch2img(patches)
        mask = self.patch2img(mask)

        return img, mask

In [None]:
class MAE_ViT(torch.nn.Module):
    """Masked autoencoder: an ``MAE_Encoder`` followed by an ``MAE_Decoder``.

    All arguments are forwarded to the encoder and/or decoder constructors;
    ``image_size``, ``patch_size`` and ``emb_dim`` are shared by both.
    """

    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 encoder_layer=12,
                 encoder_head=3,
                 decoder_layer=4,
                 decoder_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.encoder = MAE_Encoder(
            image_size,
            patch_size,
            emb_dim,
            encoder_layer,
            encoder_head,
            mask_ratio,
        )
        self.decoder = MAE_Decoder(
            image_size,
            patch_size,
            emb_dim,
            decoder_layer,
            decoder_head,
        )

    def forward(self, img):
        """Return ``(predicted_img, mask)`` as produced by the decoder for ``img``."""
        encoded, restore_indexes = self.encoder(img)
        reconstruction, mask = self.decoder(encoded, restore_indexes)
        return reconstruction, mask

In [None]:
import random
import torch
import numpy as np
def setup_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global state and PyTorch (CPU and all CUDA
    devices), and configures cuDNN for repeatable runs.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode may select different convolution algorithms from run to run,
    # which defeats the deterministic flag above, so switch it off explicitly.
    torch.backends.cudnn.benchmark = False


seed = 2022
setup_seed(seed)

In [None]:
import random
import torch
import numpy as np
import os
import math
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToTensor, Compose, Normalize
from tqdm import tqdm
from torchinfo import summary

In [None]:
# Pre-training data: CIFAR-10 converted to tensors and normalised per channel.
batch_size = 256
load_batch_size = batch_size

train_dataset = torchvision.datasets.CIFAR10(
    'data',
    train=True,
    download=True,
    transform=Compose([
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)
val_dataset = torchvision.datasets.CIFAR10(
    'data',
    train=False,
    download=True,
    transform=Compose([
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)

# Only the training split is iterated during pre-training; the validation split is indexed directly.
dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=load_batch_size,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
def get_random_img(dataset):
    """Return one ``(img, label)`` pair picked uniformly at random from ``dataset``.

    The index is drawn with NumPy's global random state; ``dataset`` must support
    ``len()`` and integer indexing that yields a two-element item.
    """
    chosen_index = np.random.randint(len(dataset))
    img, label = dataset[chosen_index]
    return img, label

In [None]:
# Fraction of patches hidden from the encoder; the pre-training loss is divided by it as well.
mask_ratio = 0.75

# Build the masked autoencoder with default sizes and move it to the selected device.
model = MAE_ViT(mask_ratio=mask_ratio)
model = model.to(device)

In [None]:
# Layer-by-layer summary for a batch of eight 3x32x32 images.
summary(model, (8, 3, 32, 32))

In [None]:
# TensorBoard logger for the pre-training run.
writer = SummaryWriter(
    os.path.join('logs', 'cifar10', 'mae-pretrain')
)

In [None]:
# Optimisation hyper-parameters for MAE pre-training.
base_learning_rate = 1.5e-4
weight_decay = 5e-2
warmup_epoch = 2
total_epoch = 20

# The learning rate scales linearly with the batch size (reference batch size changed from 256 to 512).
optim = torch.optim.AdamW(
    model.parameters(),
    lr=base_learning_rate * batch_size / 512,
    betas=(0.9, 0.95),
    weight_decay=weight_decay,
)


def lr_func(epoch):
    """Learning-rate multiplier for ``epoch``: the smaller of a linear warm-up and a cosine decay."""
    warmup_factor = (epoch + 1) / (warmup_epoch + 1e-8)
    cosine_factor = 0.5 * (math.cos(epoch / total_epoch * math.pi) + 1)
    return min(warmup_factor, cosine_factor)


lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_func, verbose=True)

In [None]:
# File the pre-training loop saves the MAE model to; it is reused later as the
# pretrained checkpoint path for fine-tuning.
model_path =  'model/vit-t-mae.pth'  # Pretrained model

In [None]:
from tqdm import tqdm_notebook as tqdm
from einops import repeat, rearrange
from einops.layers.torch import Rearrange
# Make sure the checkpoint directory exists; torch.save does not create missing parents.
checkpoint_dir = os.path.dirname(model_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

# Per-channel statistics of the `Normalize` transform applied by the datasets. They are
# needed to map normalised images back to the [0, 1] range before logging them.
# Keep these in sync with the dataset transform.
vis_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
vis_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

for e in range(total_epoch):
    # ---- one epoch of masked-autoencoder training ----
    model.train()
    losses = []
    train_step = len(dataloader)
    with tqdm(total=train_step,desc=f'Epoch {e+1}/{total_epoch}',postfix=dict,mininterval=0.3) as pbar:
        for step,(img, label) in enumerate(dataloader):

            img = img.to(device)
            predicted_img, mask = model(img)
            optim.zero_grad()
            # Mean squared error restricted to masked pixels; dividing by the mask
            # ratio compensates for averaging over all pixels.
            loss = torch.mean((predicted_img - img) ** 2 * mask) / mask_ratio
            loss.backward()
            optim.step()
            losses.append(loss.item())

            pbar.set_postfix(**{'Train Loss' : np.mean(losses)})
            pbar.update(1)

    # The schedule is defined per epoch, so step it once per epoch.
    lr_scheduler.step()
    avg_loss = sum(losses) / len(losses)
    writer.add_scalar('mae_loss', avg_loss, global_step=e)
    print(f'In epoch {e}, average traning loss is {avg_loss}.')

    # ---- visualize the first 16 validation images ----
    model.eval()
    with torch.no_grad():
        val_img = torch.stack([val_dataset[i][0] for i in range(16)])
        val_img = val_img.to(device)
        predicted_val_img, mask = model(val_img)
        # Keep the visible pixels from the input and take only masked pixels from the prediction.
        predicted_val_img = predicted_val_img * mask + val_img * (1 - mask)
        # Three panels per image: masked input, reconstruction, original.
        img = torch.cat([val_img * (1 - mask), predicted_val_img, val_img], dim=0)
        img = rearrange(img, '(v h1 w1) c h w -> c (h1 h) (w1 v w)', w1=2, v=3)
        # Undo the dataset normalisation and clamp to [0, 1]; masked regions
        # (zeros in normalised space) show up as the per-channel mean colour.
        img = (img * vis_std.to(img) + vis_mean.to(img)).clamp(0, 1)
        writer.add_image('mae_image', img, global_step=e)

    # ---- save model ----
    torch.save(model, model_path)

In [None]:
# After 20 epochs of pre-training the loss reaches roughly 0.053.
# The checkpoint written by the pre-training loop is the one fine-tuning starts from.
pretrained_model_path = model_path

In [None]:
# Fine-tuning
# The MAE fine-tuning model and a plain ViT share the same architecture; they differ only in
# how the encoder output is post-processed. ViT classifies from the cls token, whereas the MAE
# fine-tuning recipe averages the patch tokens (excluding the cls token) and classifies that mean.
# NOTE(review): the ViT_Classifier defined later in this notebook classifies from the
# cls token (features[0]), not from the mean of the patch tokens.
setup_seed(seed)  # re-seed the generators before fine-tuning

In [None]:
# Fine-tuning data: same CIFAR-10 preprocessing as pre-training, with a smaller batch size.
batch_size = 128
load_batch_size = batch_size

train_dataset = torchvision.datasets.CIFAR10(
    'data',
    train=True,
    download=True,
    transform=Compose([
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)
val_dataset = torchvision.datasets.CIFAR10(
    'data',
    train=False,
    download=True,
    transform=Compose([
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)

# Training batches are reshuffled every epoch; validation keeps the dataset order.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=load_batch_size,
    shuffle=True,
    num_workers=4,
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=load_batch_size,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Optimisation hyper-parameters for fine-tuning.
base_learning_rate = 1e-3
weight_decay = 0.05

# Schedule length: the warm-up spans as many epochs as the whole fine-tuning run.
warmup_epoch = 5
total_epoch = 5

# Checkpoint to start from and file the best classifier is saved to.
pretrained_model_path = model_path
output_model_path = 'vit-t-classifier.pth'

In [None]:
class ViT_Classifier(torch.nn.Module):
    """Image classifier built on top of an ``MAE_Encoder``.

    The encoder's class token, positional embedding, patch embedding, transformer
    and layer norm are reused as-is (shared, not copied); no patch shuffling or
    masking is applied. A new linear head maps the class-token feature to
    ``num_classes`` logits.
    """

    def __init__(self, encoder : MAE_Encoder, num_classes=10) -> None:
        super().__init__()
        # Share the encoder's parameters and sub-modules.
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm

        # The embedding width is read off the last axis of the positional embedding.
        emb_dim = self.pos_embedding.shape[-1]
        self.head = torch.nn.Linear(emb_dim, num_classes)

    def forward(self, img):
        """Return the classification logits computed from the class token of ``img``."""
        tokens = rearrange(self.patchify(img), 'b c h w -> (h w) b c')
        tokens = tokens + self.pos_embedding

        # Prepend the class token and run the transformer in batch-first layout.
        cls = self.cls_token.expand(-1, tokens.shape[1], -1)
        sequence = rearrange(torch.cat([cls, tokens], dim=0), 't b c -> b t c')
        encoded = self.layer_norm(self.transformer(sequence))

        # Back to token-first layout; row 0 is the class token.
        cls_features = rearrange(encoded, 'b t c -> t b c')[0]
        return self.head(cls_features)


In [None]:
# Wrap the pre-trained encoder in a classifier. CIFAR-10 has 10 classes, so the head
# needs 10 outputs; a single-output head cannot be trained with CrossEntropyLoss on
# labels 0..9.
model = ViT_Classifier(model.encoder, num_classes=10).to(device)

In [None]:
# Layer-by-layer summary of the current model for a batch of eight 3x32x32 images.
summary(model, (8, 3, 32, 32))

In [None]:

# Pick the encoder to fine-tune: the pre-trained MAE checkpoint if a path is given,
# otherwise a freshly initialised MAE (training the classifier from scratch).
if pretrained_model_path is not None:
    # NOTE: torch.load unpickles a whole model object; only load checkpoints you trust.
    mae_model = torch.load(pretrained_model_path, map_location='cpu')
    writer = SummaryWriter(os.path.join('logs', 'cifar10', 'pretrain-cls'))
else:
    mae_model = MAE_ViT()
    writer = SummaryWriter(os.path.join('logs', 'cifar10', 'scratch-cls'))

# The fine-tuning loop expects `model(img)` to return class logits. The MAE model
# returns (image, mask) instead, so wrap its encoder in the classifier and move it to
# the training device (the checkpoint is loaded onto the CPU above).
model = ViT_Classifier(mae_model.encoder, num_classes=10).to(device)

In [None]:
# Classification loss on raw logits.
loss_fn = torch.nn.CrossEntropyLoss()


def acc_fn(logit, label):
    """Fraction of samples whose arg-max logit equals ``label``, as a scalar float tensor."""
    predictions = logit.argmax(dim=-1)
    return torch.mean((predictions == label).float())


# The learning rate scales linearly with the batch size (reference batch size 512).
optim = torch.optim.AdamW(
    model.parameters(),
    lr=base_learning_rate * batch_size / 512,
    betas=(0.9, 0.999),
    weight_decay=weight_decay,
)


def lr_func(epoch):
    """Learning-rate multiplier for ``epoch``: the smaller of a linear warm-up and a cosine decay."""
    warmup_factor = (epoch + 1) / (warmup_epoch + 1e-8)
    cosine_factor = 0.5 * (math.cos(epoch / total_epoch * math.pi) + 1)
    return min(warmup_factor, cosine_factor)


lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_func, verbose=True)

In [None]:
# Fine-tune the classifier, validating after every epoch and keeping the checkpoint
# with the best validation accuracy.
best_val_acc = 0
step_count = 0
optim.zero_grad()
for e in range(total_epoch):
    # ---- training ----
    model.train()
    losses = []
    acces = []
    train_step = len(train_dataloader)
    with tqdm(total=train_step,desc=f'Train Epoch {e+1}/{total_epoch}',postfix=dict,mininterval=0.3) as pbar:
        for step,(img, label) in enumerate(train_dataloader):
            img = img.to(device)
            label = label.to(device)
            logits = model(img)
            loss = loss_fn(logits, label)
            acc = acc_fn(logits, label)
            loss.backward()

            optim.step()
            optim.zero_grad()
            losses.append(loss.item())
            acces.append(acc.item())

            pbar.set_postfix(**{'Train Loss' : np.mean(losses),
                      'Tran accs': np.mean(acces)})
            pbar.update(1)
    # The schedule is defined per epoch, so step it once per epoch.
    lr_scheduler.step()
    avg_train_loss = sum(losses) / len(losses)
    avg_train_acc = sum(acces) / len(acces)
    print(f'In epoch {e}, average training loss is {avg_train_loss}, average training acc is {avg_train_acc}.')

    # ---- validation ----
    model.eval()
    with torch.no_grad():
        losses = []
        acces = []
        val_step = len(val_dataloader)
        with tqdm(total=val_step,desc=f'Val Epoch {e+1}/{total_epoch}',postfix=dict,mininterval=0.3) as pbar2:
            # Iterate the loader directly: pbar2 already tracks progress, so wrapping
            # the loader in another tqdm would draw a second, redundant bar.
            for img, label in val_dataloader:
                img = img.to(device)
                label = label.to(device)
                logits = model(img)
                loss = loss_fn(logits, label)
                acc = acc_fn(logits, label)
                losses.append(loss.item())
                acces.append(acc.item())
                pbar2.set_postfix(**{'Val Loss' : np.mean(losses),
                          'Val accs': np.mean(acces)})
                pbar2.update(1)
        # Unweighted mean of the per-batch metrics.
        avg_val_loss = sum(losses) / len(losses)
        avg_val_acc = sum(acces) / len(acces)
        print(f'In epoch {e}, average validation loss is {avg_val_loss}, average validation acc is {avg_val_acc}.')

    # ---- keep the best checkpoint ----
    if avg_val_acc > best_val_acc:
        best_val_acc = avg_val_acc
        print(f'saving best model with acc {best_val_acc} at {e} epoch!')
        torch.save(model, output_model_path)

    writer.add_scalars('cls/loss', {'train' : avg_train_loss, 'val' : avg_val_loss}, global_step=e)
    writer.add_scalars('cls/acc', {'train' : avg_train_acc, 'val' : avg_val_acc}, global_step=e)

# Recorded result of a previous run: best model saved with acc 0.6529865506329114 at epoch 4.