## Masked Autoencoders: Visualization Demo

This is a visualization demo using our pre-trained MAE models. No GPU is needed.

In [1]:
import json
token={"username":"mcbrandon","key":"0707eb6fd598fbb59105a4020cce3f1c"}

with open('/content/kaggle.json', 'w') as file:
  json.dump(token, file)
!mkdir -p ~/.kaggle
!cp /content/kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle config set -n path -v /content

- path is now set to: /content


In [2]:
#!/bin/bash
# NOTE: the shebang above is inert inside a notebook cell; the command below
# runs through IPython's `!` shell escape.
# Download the imagenetmini-1000 archive with the kaggle CLI. The recorded
# output shows it landing in /content/datasets/ifigotin/imagenetmini-1000.
!kaggle datasets download ifigotin/imagenetmini-1000

Dataset URL: https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000
License(s): unknown
Downloading imagenetmini-1000.zip to /content/datasets/ifigotin/imagenetmini-1000
100% 3.91G/3.92G [00:52<00:00, 103MB/s]
100% 3.92G/3.92G [00:52<00:00, 79.5MB/s]


In [None]:
!unzip datasets/ifigotin/imagenetmini-1000/imagenetmini-1000.zip

### Prepare
Check environment. Install packages if in Colab.


In [3]:
import sys
import os
import requests

import torch
import numpy as np

import matplotlib.pyplot as plt
from PIL import Image

# check whether run in Colab
if 'google.colab' in sys.modules:
    print('Running in Colab.')
    !pip3 install timm==0.4.5  # 0.3.2 does not work in Colab
    !git clone https://github.com/facebookresearch/mae.git
    sys.path.append('./mae')
else:
    sys.path.append('..')
import models_mae

Running in Colab.
Collecting timm==0.4.5
  Downloading timm-0.4.5-py3-none-any.whl.metadata (24 kB)
Downloading timm-0.4.5-py3-none-any.whl (287 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m287.4/287.4 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: timm
  Attempting uninstall: timm
    Found existing installation: timm 1.0.12
    Uninstalling timm-1.0.12:
      Successfully uninstalled timm-1.0.12
Successfully installed timm-0.4.5
Cloning into 'mae'...
remote: Enumerating objects: 39, done.[K
remote: Total 39 (delta 0), reused 0 (delta 0), pack-reused 39 (from 1)[K
Receiving objects: 100% (39/39), 834.95 KiB | 2.44 MiB/s, done.
Resolving deltas: 100% (9/9), done.


KeyboardInterrupt: 

### Define utils

In [None]:
# define the utils

# Per-channel mean and standard deviation (three entries each). The image
# loading cell subtracts the mean and divides by the std; show_image applies
# the inverse before plotting. Presumably in RGB order — confirm against the
# image loader.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

def show_image(image, title=''):
    """Plot a normalized image on the current matplotlib axes.

    Args:
        image: Tensor laid out as [H, W, 3] that was normalized with
            ``imagenet_mean`` / ``imagenet_std``.
        title: Optional title drawn above the image.
    """
    assert image.shape[2] == 3
    # Undo the mean/std normalization and rescale to the 0..255 range.
    pixels = (image * imagenet_std + imagenet_mean) * 255
    # Clamp out-of-range values, then truncate to integers for display.
    pixels = torch.clip(pixels, 0, 255).int()
    plt.imshow(pixels.numpy())
    plt.title(title, fontsize=16)
    plt.axis('off')

def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    """Instantiate an MAE architecture and load checkpoint weights into it.

    Args:
        chkpt_dir: Path handed to ``torch.load``. Despite the name, callers in
            this notebook pass a checkpoint *file* path.
        arch: Name of the model constructor to look up on ``models_mae``.

    Returns:
        The constructed model after loading the checkpoint's ``'model'`` entry.
    """
    build_fn = getattr(models_mae, arch)
    net = build_fn()
    # Map tensors to CPU so the checkpoint loads without a GPU.
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    state = torch.load(chkpt_dir, map_location='cpu')
    # strict=False tolerates key mismatches; the result is printed so that any
    # missing or unexpected keys are visible.
    load_result = net.load_state_dict(state['model'], strict=False)
    print(load_result)
    return net

def run_one_image(img, model, mask_ratio=0.5):
    """Run MAE on a single image and plot the result as four panels.

    The panels are: original, masked input, raw reconstruction, and the
    reconstruction pasted together with the visible patches.

    Args:
        img: Normalized image laid out as [H, W, 3] (array-like or tensor).
        model: MAE model. It is called as ``model(x, mask_ratio=...)`` and must
            provide ``unpatchify`` and ``patch_embed.patch_size``.
        mask_ratio: Mask ratio forwarded to the model. Defaults to 0.5, the
            value that used to be hardcoded here.
    """
    # as_tensor + detach accepts arrays and tensors alike and avoids the
    # copy-construct warning torch.tensor() emits when img is already a tensor.
    x = torch.as_tensor(img).detach()

    # make it a batch-like
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)

    # run MAE (inference only, so do not build an autograd graph)
    with torch.no_grad():
        _, y, mask = model(x.float(), mask_ratio=mask_ratio)
        y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    # visualize the mask
    mask = mask.detach()
    mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 *3)  # (N, H*W, p*p*3)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    x = torch.einsum('nchw->nhwc', x)

    # NOTE(review): `deprocess` is not defined in the visible notebook; it has
    # to exist in the kernel already, otherwise this line raises NameError.
    # It is applied to channel-first [3, H, W] tensors.
    recon_chw = y.squeeze().permute(2,0,1)
    orig_chw = x.squeeze().permute(2,0,1)
    mean_pert = torch.abs(deprocess(recon_chw)-deprocess(orig_chw)).numpy().mean()
    print(f"mean pert:{mean_pert}")

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 24]

    plt.subplot(1, 4, 1)
    show_image(x[0], "original")

    plt.subplot(1, 4, 2)
    show_image(im_masked[0], "masked")

    plt.subplot(1, 4, 3)
    show_image(y[0], "reconstruction")

    plt.subplot(1, 4, 4)
    show_image(im_paste[0], "reconstruction + visible")

    plt.show()

### Load an image

In [None]:
# load an image
img_url = 'https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg' # fox, from ILSVRC2012_val_00046145
# img_url = 'https://user-images.githubusercontent.com/11435359/147743081-0428eecf-89e5-4e07-8da5-a30fd73cc0ba.jpg' # cucumber, from ILSVRC2012_val_00047851
img_path = "144.jpg"
if os.path.exists(img_path):
    img = Image.open(img_path)
else:
    # Nothing in this notebook creates the local file, so fall back to the
    # demo image to keep the cell runnable on a fresh runtime.
    response = requests.get(img_url, stream=True, timeout=30)
    response.raise_for_status()
    img = Image.open(response.raw)
# Force three channels: grayscale / RGBA / palette images would otherwise
# fail the shape assertion below.
img = img.convert('RGB')
img = img.resize((224, 224))
img = np.array(img) / 255.

assert img.shape == (224, 224, 3)

# normalize by ImageNet mean and std
img = img - imagenet_mean
img = img / imagenet_std

plt.rcParams['figure.figsize'] = [5, 5]
show_image(torch.tensor(img))

### Load a pre-trained MAE model

In [None]:
# This is an MAE model trained with pixels as targets for visualization (ViT-Large, training mask ratio=0.75)

# download checkpoint if not exist
# (-nc makes wget skip the download when the file is already present)
!wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth

# Build the architecture and load the downloaded weights; prepare_model prints
# the key-matching report returned by load_state_dict.
chkpt_dir = 'mae_visualize_vit_large.pth'
model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
print('Model loaded.')


--2023-03-18 05:21:15--  https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 18.154.144.13, 18.154.144.102, 18.154.144.87, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|18.154.144.13|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1318315181 (1.2G) [binary/octet-stream]
Saving to: ‘mae_visualize_vit_large.pth’


2023-03-18 05:21:23 (159 MB/s) - ‘mae_visualize_vit_large.pth’ saved [1318315181/1318315181]

<All keys matched successfully>
Model loaded.


In [None]:
# NOTE(review): `torch_dataset` is not defined anywhere in the visible
# notebook; it must already exist in the kernel — confirm where it comes from.
# Take item 5 and move its first axis last. The recorded output shape is
# [224, 224, 3], the [H, W, 3] layout that run_one_image expects.
# This overwrites the `img` produced by the image-loading cell.
img = torch_dataset[5].permute(1,2,0)
img.shape

torch.Size([224, 224, 3])

In [None]:
# make random mask reproducible (comment out to make it change)
torch.manual_seed(2)
print('MAE with pixel reconstruction:')
# Uses whichever `img` was assigned most recently and the pixel-target model.
run_one_image(img, model_mae)

### Run MAE on the image

### Load another pre-trained MAE model

In [None]:
# This is an MAE model trained with an extra GAN loss for more realistic generation (ViT-Large, training mask ratio=0.75)

# download checkpoint if not exist

!wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large_ganloss.pth

chkpt_dir = 'mae_visualize_vit_large_ganloss.pth'
model_mae_gan = prepare_model('mae_visualize_vit_large_ganloss.pth', 'mae_vit_large_patch16')
print('Model loaded.')

--2023-03-19 02:10:24--  https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large_ganloss.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 104.22.74.142, 172.67.9.4, 104.22.75.142, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|104.22.74.142|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1318315181 (1.2G) [binary/octet-stream]
Saving to: ‘mae_visualize_vit_large_ganloss.pth’


2023-03-19 02:10:51 (47.4 MB/s) - ‘mae_visualize_vit_large_ganloss.pth’ saved [1318315181/1318315181]

<All keys matched successfully>
Model loaded.


In [None]:
#@title 复制、解压CDTA数据集
import json
token={"username":"mcbrandon","key":"0707eb6fd598fbb59105a4020cce3f1c"}

with open('/content/kaggle.json', 'w') as file:
  json.dump(token, file)
!mkdir -p ~/.kaggle
!cp /content/kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle config set -n path -v /content

!kaggle datasets download -d jessicali9530/stanford-cars-dataset
!kaggle datasets download -d quanbk/svhndataset
!kaggle datasets download -d pratt3000/stl10-binary-files
!wget https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz

!unzip datasets/quanbk/svhndataset/svhndataset.zip
!unzip datasets/pratt3000/stl10-binary-files/stl10-binary-files.zip
os.makedirs(f"CDTA_datasets",exist_ok=True)
shutil.copyfile("drive/MyDrive/CDTA_datasets/birds-400.zip", "CDTA_datasets/birds-400.zip")
shutil.copyfile("drive/MyDrive/CDTA_datasets/oxford_flower_102.zip", "CDTA_datasets/oxford_flower_102.zip")

os.makedirs("CDTA_datasets/food_101",exist_ok=True)
os.makedirs("CDTA_datasets/comic_books",exist_ok=True)
os.makedirs("CDTA_datasets/comic_books_train",exist_ok=True)
!unzip CDTA_datasets/birds-400.zip -d CDTA_datasets/birds_400
!unzip CDTA_datasets/oxford_flower_102.zip -d CDTA_datasets/oxford_flower_102
!unzip datasets/jessicali9530/stanford-cars-dataset/stanford-cars-dataset.zip -d CDTA_datasets/
!tar -xzvf cifar-100-python.tar.gz

- path is now set to: /content
Downloading stanford-cars-dataset.zip to /content/datasets/jessicali9530/stanford-cars-dataset
 99% 1.81G/1.82G [00:23<00:00, 63.6MB/s]
100% 1.82G/1.82G [00:23<00:00, 83.5MB/s]
Downloading svhndataset.zip to /content/datasets/quanbk/svhndataset
 99% 1.46G/1.47G [00:17<00:00, 46.0MB/s]
100% 1.47G/1.47G [00:17<00:00, 88.8MB/s]
Downloading stl10-binary-files.zip to /content/datasets/pratt3000/stl10-binary-files
 21% 529M/2.44G [00:10<00:31, 65.3MB/s]

In [17]:
# Fetch the GAN-loss MAE checkpoint into the current working directory.
# Same URL as in the cell that builds model_mae_gan; -nc skips the download
# when the file already exists here.
!wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large_ganloss.pth

--2025-01-15 11:29:09--  https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large_ganloss.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 18.154.144.102, 18.154.144.87, 18.154.144.13, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|18.154.144.102|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1318315181 (1.2G) [binary/octet-stream]
Saving to: ‘mae_visualize_vit_large_ganloss.pth’


2025-01-15 11:29:29 (63.3 MB/s) - ‘mae_visualize_vit_large_ganloss.pth’ saved [1318315181/1318315181]



### Mount Google Drive and set up the CDTA project

In [5]:
# Mount Google Drive at /content/drive (Colab only). Other cells in this
# notebook read from and change directory into paths under drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Clone the CDTA repository into the current working directory.
# NOTE(review): `git clone` fails if a `CDTA` directory already exists, so this
# cell is not re-runnable as written.
!git clone https://github.com/LiulietLee/CDTA.git

Cloning into 'CDTA'...
remote: Enumerating objects: 80, done.[K
remote: Counting objects: 100% (80/80), done.[K
remote: Compressing objects: 100% (48/48), done.[K
remote: Total 80 (delta 29), reused 47 (delta 20), pack-reused 0 (from 0)[K
Receiving objects: 100% (80/80), 27.93 KiB | 1021.00 KiB/s, done.
Resolving deltas: 100% (29/29), done.


In [None]:
!unzip datasets/ifigotin/imagenetmini-1000/imagenetmini-1000.zip

In [None]:
# Enter the CDTA checkout. The path is relative: the recorded output resolves
# it to /content/drive/MyDrive/Project2023_MAE_attack/CDTA, so the working
# directory must already be the Drive project folder when this cell runs.
%cd CDTA

/content/drive/MyDrive/Project2023_MAE_attack/CDTA


In [None]:
# Move into CDTA/pretrained/surrogate (relative to the CDTA checkout); the
# next cell downloads a checkpoint archive into this directory.
%cd pretrained/surrogate

/content/drive/MyDrive/Project2023_MAE_attack/CDTA/pretrained/surrogate


In [None]:
!wget https://github.com/LiulietLee/CDTA/releases/download/v1.0/simsiam_bs256_100ep_cst.tar

--2024-11-08 06:39:52--  https://github.com/LiulietLee/CDTA/releases/download/v1.0/simsiam_bs256_100ep_cst.tar
Resolving github.com (github.com)... 140.82.112.3
Connecting to github.com (github.com)|140.82.112.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/568768294/01877802-46ea-4756-8f77-abca81cc07ea?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20241108%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20241108T063952Z&X-Amz-Expires=300&X-Amz-Signature=add80dd699f4c3c5606fe34a92377bd2dd8cd1e74445e1da5b62875dde67a805&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3Dsimsiam_bs256_100ep_cst.tar&response-content-type=application%2Foctet-stream [following]
--2024-11-08 06:39:52--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/568768294/01877802-46ea-4756-8f77-abca81cc07ea?X-Amz-Algorithm=AWS4-HM

In [6]:
# Switch to the project folder on the mounted Google Drive. The path is
# relative, so this assumes the working directory is /content (see the recorded
# output) and that Drive has been mounted.
%cd drive/MyDrive/Project2023_MAE_attack

/content/drive/MyDrive/Project2023_MAE_attack


In [None]:
# Run the KL_FIA.py script from the current working directory.
# NOTE(review): the script is not part of this notebook; presumably it lives in
# the Drive project folder entered by the preceding `%cd` cell — confirm.
!python KL_FIA.py