## i-MAE: Visualization Demo
This is a visualization demo of our pre-trained i-MAE model reconstructing only the subordinate image. The notebook is based on: https://github.com/facebookresearch/mae.

### Prepare
Check the environment and install packages if running in Colab.

In [None]:
import sys
# check whether run in Colab
# In Colab: install the pinned timm version and clone the i-MAE repo, then put
# the clone on sys.path so that `import models_mae` below resolves.
# NOTE(review): outside Colab nothing is installed or cloned here, so the
# notebook presumably has to be launched from the i-mae repo root — confirm.
# NOTE(review): re-running this cell in Colab will attempt the clone again
# (git refuses to clone into an existing directory; `!` does not raise).
if 'google.colab' in sys.modules:
    print('Running in Colab.')
    !pip3 install timm==0.4.5  # 0.3.2 does not work in Colab
    !git clone https://github.com/vision-learning-acceleration-lab/i-mae.git
    sys.path.append('./i-mae')
import io
import os

import matplotlib.pyplot as plt
import numpy as np
import requests
import timm
import torch
import torchvision
from PIL import Image

import models_mae

### Define utils
Define helpers for displaying normalized images, loading a checkpoint, and running the model on one image triple.

In [None]:
# Per-channel ImageNet statistics: used to normalize the input images and to
# undo that normalization when displaying them.
imagenet_mean, imagenet_std = (
    np.array((0.485, 0.456, 0.406)),
    np.array((0.229, 0.224, 0.225)),
)

def show_image(image, title=''):
    """Draw one ImageNet-normalized image on the current matplotlib axes.

    Args:
        image: ``[H, W, 3]`` torch tensor in ImageNet-normalized space
            (``torch.clip`` below requires a tensor).
        title: caption shown above the image.
    """
    assert image.shape[2] == 3
    # Undo the normalization and map to 0..255 integer pixels for imshow.
    pixels = (image * imagenet_std + imagenet_mean) * 255
    plt.imshow(torch.clip(pixels, 0, 255).int())
    plt.title(title, fontsize=16)
    plt.axis('off')

def prepare_model(chkpt_dir, arch='mae_vit_small_patch8_dec1928b'):
    """Build ``models_mae.<arch>`` and load weights from a checkpoint file.

    Args:
        chkpt_dir: path to the checkpoint *file* (despite the name).
        arch: name of a model constructor attribute in ``models_mae``.

    Returns:
        The model, with the checkpoint's ``'model'`` state dict loaded
        non-strictly. The missing/unexpected-key report is printed.
    """
    build = getattr(models_mae, arch)
    net = build()
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    state = torch.load(chkpt_dir, map_location='cpu')['model']
    # strict=False tolerates key mismatches; print the report so they stay visible.
    print(net.load_state_dict(state, strict=False))
    return net

def run_one_image(img, model, mask_ratio):
    """Run i-MAE on one image triple and plot the inputs, mask and reconstruction.

    Args:
        img: sequence of three ``[H, W, 3]`` ImageNet-normalized images, in the
            order ``[mixture, subordinate target, dominant target]``. The panel
            titles below assume this order.
        model: i-MAE model (e.g. from ``prepare_model``). It is called as
            ``model(inputs, weak_idx=0, mask_ratio=...)`` and its result is
            unpacked as ``(loss, predictions, mask)``.
        mask_ratio: masking ratio forwarded to the model.
    """
    # Stack into one [num_images, H, W, 3] tensor. torch.tensor() on a list of
    # ndarrays works but is slow and emits a UserWarning.
    x = torch.from_numpy(np.stack(img))

    # The model takes a list of batch-of-one NCHW float32 tensors.
    inputs = [torch.einsum('nhwc->nchw', im.unsqueeze(dim=0)).float() for im in x]

    # Visualization only: no autograd graph is needed.
    with torch.no_grad():
        _, y, mask = model(inputs, weak_idx=0, mask_ratio=mask_ratio)
        # NOTE(review): only y[0] is visualized — presumably the subordinate
        # reconstruction, given the panel title below; confirm.
        y = model.unpatchify(y[0])
        y = torch.einsum('nchw->nhwc', y).cpu()

        # Expand the per-patch mask to pixel resolution for compositing.
        patch = model.patch_embed.patch_size[0]
        mask = mask.unsqueeze(-1).repeat(1, 1, patch ** 2 * 3)  # (N, H*W, p*p*3)
        mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
        mask = torch.einsum('nchw->nhwc', mask).cpu()

    # Back to NHWC so they combine elementwise with the NHWC mask / reconstruction.
    mixture = torch.einsum('nchw->nhwc', inputs[0])
    subordinate = torch.einsum('nchw->nhwc', inputs[1])

    # Masked input: removed patches painted white (pixel value 1.0 in normalized space).
    white = (1 - torch.Tensor(imagenet_mean)) / torch.Tensor(imagenet_std)
    im_masked = mixture * (1 - mask) + white * mask

    # MAE reconstruction in masked patches, subordinate target's pixels elsewhere.
    im_paste = subordinate * (1 - mask) + y * mask

    panels = [
        (x[0], "mixture input"),
        (im_masked[0], "input + mask"),
        (y[0], "subordinate reconstruction"),
        (im_paste[0], "reconstruction + visible"),
        (x[1], "subordinate target"),
        (x[2], "dominant target"),
    ]
    # Size this figure explicitly instead of mutating the global plt.rcParams.
    plt.figure(figsize=(24, 24))
    for pos, (image, title) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), pos)
        show_image(image, title)
    plt.show()

### Load an image
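Load two ImageNet validation images (a fox and a cucumber), resize them to 224×224, and normalize them with the ImageNet mean and std. They are mixed into a single input further below.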

In [None]:
def load_image(url, size=(224, 224), timeout=30):
    """Download an image and return it ImageNet-normalized.

    Args:
        url: HTTP(S) URL of the image.
        size: ``(width, height)`` passed to ``PIL.Image.resize``.
        timeout: seconds to wait for the HTTP response.

    Returns:
        A float64 ``[H, W, 3]`` array normalized with ``imagenet_mean`` and
        ``imagenet_std``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly rather than handing an HTML error page to PIL.
    response.raise_for_status()
    with Image.open(io.BytesIO(response.content)) as raw:
        # Force 3 channels; grayscale/RGBA inputs would break the [H, W, 3] layout.
        rgb = raw.convert('RGB').resize(size)
    pixels = np.array(rgb) / 255.
    return (pixels - imagenet_mean) / imagenet_std


img_url_1 = 'https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg' # fox, from ILSVRC2012_val_00046145
img_url_2 = 'https://user-images.githubusercontent.com/11435359/147743081-0428eecf-89e5-4e07-8da5-a30fd73cc0ba.jpg' # cucumber, from ILSVRC2012_val_00047851

im_1 = load_image(img_url_1)
im_2 = load_image(img_url_2)

### Load a pre-trained i-MAE model

This is an i-MAE model trained with pixels as reconstruction targets, used here for visualization (ViT-Base, training mask ratio = 0.5).

In [None]:
# download checkpoint if not exist
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1rpXg8r15cpAtTK4QkaNY725KZ9jTJpBL' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1rpXg8r15cpAtTK4QkaNY725KZ9jTJpBL" -O vit-base-sub-imae.pth && rm -rf /tmp/cookies.txt

chkpt_dir = 'vit-base-sub-imae.pth'
model_mae = prepare_model(chkpt_dir, 'mae_vit_base_patch16')
print('Model loaded.')

# Inference settings: the fraction of patches to mask, and the weight given to
# im_1 (the subordinate image) when mixing.
mask_ratio = 0.5
mix_ratio = 0.3

# Seed torch's global RNG (presumably consumed by the model's random patch
# masking) so the figure is reproducible.
torch.manual_seed(4)
print('MAE with pixel reconstruction mix ratio:{}, maskratio:{}'.format(mix_ratio, mask_ratio))

# Linear blend: im_1 is weighted by mix_ratio, im_2 by the remainder.
img_comb = mix_ratio * im_1 + (1 - mix_ratio) * im_2
imgs = [img_comb, im_1, im_2]  # [mixture, subordinate target, dominant target]
run_one_image(imgs, model_mae, mask_ratio=mask_ratio)