In [1]:
# Mount Google Drive at /content/drive (Colab-only). The default
# --pretrained_model path used later in this notebook lives under this mount.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!git clone https://github.com/YLtrees2/ViT-pytorch-Low-rank-Approximation.git
%cd ViT-pytorch-Low-rank-Approximation/
!pip install ml-collections
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-AxL45qSt354FadCyK375_60ehXZfCJs' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-AxL45qSt354FadCyK375_60ehXZfCJs" -O cifar10-100_500_checkpoint.bin && rm -rf /tmp/cookies.txt

Cloning into 'ViT-pytorch-Low-rank-Approximation'...
remote: Enumerating objects: 190, done.[K
remote: Counting objects: 100% (60/60), done.[K
remote: Compressing objects: 100% (33/33), done.[K
remote: Total 190 (delta 44), reused 27 (delta 27), pack-reused 130[K
Receiving objects: 100% (190/190), 21.31 MiB | 33.06 MiB/s, done.
Resolving deltas: 100% (97/97), done.
/content/ViT-pytorch-Low-rank-Approximation
Collecting ml-collections
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
[K     |████████████████████████████████| 77 kB 2.6 MB/s 
Building wheels for collected packages: ml-collections
  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone
  Created wheel for ml-collections: filename=ml_collections-0.1.1-py3-none-any.whl size=94524 sha256=fc584ad456bf7a81e77cbff6650b4d35a7882945811947bbfe67ed88907dc0d2
  Stored in directory: /root/.cache/pip/wheels/b7/da/64/33c926a1b10ff19791081b705879561b715a8341a856a3bbd2
Successfully built ml-collections
Installing collecte

##### In this model, each self-attention module is replaced by Nyströmformer attention (a Nyström-based algorithm for approximating self-attention).

The default version uses 32 landmark (Nyström) points to reconstruct the softmax matrix in self-attention.

In [3]:
from __future__ import absolute_import, division, print_function

import logging
import argparse
import os
import random
import numpy as np

from datetime import timedelta

import torch
import torch.distributed as dist

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from models.modeling_Nystromformer import VisionTransformer, CONFIGS
from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
from utils.data_utils import get_loader
from utils.dist_util import get_world_size

(Uncomment and run the following cell each time after the model code has been updated, so that the changes are reloaded.)

In [None]:
#import sys, importlib
#importlib.reload(sys.modules['models.modeling_Nystromformer'])
#from models.modeling_Nystromformer import VisionTransformer, CONFIGS

In [4]:
# Notebook configuration, expressed as an argparse parser so the option names
# match the ones the project utilities read from `args`.
# Every option has a default, so nothing has to be supplied on a command line.
parser = argparse.ArgumentParser()

# Run identity, model variant and checkpoint locations.
parser.add_argument("--name", default="cifar100",
                    help="Name of this run. Used for monitoring.")
parser.add_argument("--dataset", choices=["cifar10", "cifar100"], default="cifar100",
                    help="Which downstream task.")
parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16",
                                              "ViT-L_32", "ViT-H_14", "R50-ViT-B_16"],
                    default="ViT-B_16",
                    help="Which variant to use.")
parser.add_argument("--pretrained_dir", type=str, default="ViT-B_16.npz",
                    help="Where to search for pretrained ViT models.")
parser.add_argument("--pretrained_model", type=str, default="/content/drive/MyDrive/cifar100_checkpoint.bin",
                    help="Path to the checkpoint whose state dict is loaded into the model.")
parser.add_argument("--output_dir", default="output", type=str,
                    help="The output directory where checkpoints will be written.")

# Input resolution, batch sizes and evaluation cadence.
parser.add_argument("--img_size", default=224, type=int,
                    help="Resolution size")
parser.add_argument("--train_batch_size", default=512, type=int,
                    help="Total batch size for training.")
parser.add_argument("--eval_batch_size", default=64, type=int,
                    help="Total batch size for eval.")
parser.add_argument("--eval_every", default=100, type=int,
                    help="Run prediction on validation set every so many steps. "
                         "Will always run one evaluation at the end of training.")

# Optimisation schedule.
parser.add_argument("--learning_rate", default=3e-2, type=float,
                    help="The initial learning rate for SGD.")
parser.add_argument("--weight_decay", default=0, type=float,
                    help="Weight decay if we apply some.")
parser.add_argument("--num_steps", default=10000, type=int,
                    help="Total number of training steps to perform.")
parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
                    help="How to decay the learning rate.")
parser.add_argument("--warmup_steps", default=500, type=int,
                    help="Step of training to perform learning rate warmup for.")
parser.add_argument("--max_grad_norm", default=1.0, type=float,
                    help="Max gradient norm.")

# Distributed execution, seeding and mixed precision.
parser.add_argument("--local_rank", type=int, default=-1,
                    help="local_rank for distributed training on gpus")
parser.add_argument('--seed', type=int, default=42,
                    help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                    help="Number of update steps to accumulate before performing a backward/update pass.")
parser.add_argument('--fp16', action='store_true',
                    help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O2',
                    help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
                         "See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument('--loss_scale', type=float, default=0,
                    help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                         "0 (default value): dynamic loss scaling.\n"
                         "Positive power of 2: static loss scaling value.\n")

# Parse an empty argument list: inside a notebook there is no command line of
# our own, so every option takes its default (args.name is "cifar100").
args = parser.parse_args([])

# Setup CUDA, GPU & distributed training
if args.local_rank == -1:
    # Single-process mode: use CUDA when available, otherwise the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()
else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    torch.distributed.init_process_group(backend='nccl',
                                         timeout=timedelta(minutes=60))
    args.n_gpu = 1

# Expose the selected device on `args`; later cells read `args.device`.
args.device = device

In [5]:
# Build the default model for the configured variant and dataset, restore its
# weights from the checkpoint, move it to the target device, then create the
# data loaders.
config = CONFIGS[args.model_type]
# 10 classes for "cifar10"; any other dataset value maps to 100.
num_classes = {"cifar10": 10}.get(args.dataset, 100)
model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
model.load_state_dict(
    torch.load(args.pretrained_model, map_location=torch.device(args.device))
)
model.to(args.device)
train_loader, test_loader = get_loader(args)


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [6]:
def simple_accuracy(preds, labels):
    """Return the fraction of positions at which ``preds`` equals ``labels``.

    The two arguments are compared element-wise with ``==`` and the mean of
    the resulting boolean array is returned, so both are expected to be
    arrays that support this (e.g. NumPy arrays of the same shape).
    """
    matches = preds == labels
    return matches.mean()

def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and report its accuracy.

    The model is switched to eval mode and run without gradient tracking.
    Each batch is moved to ``args.device``; the first element of the model
    output is taken as the logits and the arg-max over the last dimension as
    the predicted class. The progress bar shows the cross-entropy loss of the
    most recent batch.

    Relies on the notebook-level ``args`` (``device`` and ``local_rank``).

    Args:
        model: Module whose call returns a sequence with the logits first.
        test_loader: Iterable yielding ``(inputs, labels)`` tensor pairs.

    Returns:
        The accuracy over all batches, which is also printed as
        ``Valid Accuracy``.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()
    epoch_iterator = tqdm(test_loader,
                          desc="Validating... (loss=X.X)",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True,
                          disable=args.local_rank not in [-1, 0])
    loss_fct = torch.nn.CrossEntropyLoss()

    # Collect per-batch arrays and concatenate once at the end; growing a
    # single array with np.append re-copies everything seen so far each step.
    batch_preds, batch_labels = [], []
    with torch.no_grad():
        for batch in epoch_iterator:
            x, y = tuple(t.to(args.device) for t in batch)
            logits = model(x)[0]
            eval_loss = loss_fct(logits, y)
            preds = torch.argmax(logits, dim=-1)

            batch_preds.append(preds.cpu().numpy())
            batch_labels.append(y.cpu().numpy())
            # Replace the "loss=X.X" placeholder with the latest batch loss.
            epoch_iterator.set_description(
                "Validating... (loss=%2.5f)" % eval_loss.item()
            )

    if not batch_preds:
        raise ValueError("test_loader yielded no batches; nothing to evaluate.")

    all_preds = np.concatenate(batch_preds, axis=0)
    all_label = np.concatenate(batch_labels, axis=0)
    accuracy = simple_accuracy(all_preds, all_label)
    print("Valid Accuracy: %2.5f" % accuracy)
    return accuracy

In [7]:
# Evaluate the default (32-landmark) model on the test loader.
test(model, test_loader)

Validating... (loss=X.X): 100%|| 157/157 [01:01<00:00,  2.55it/s]

Valid Accuracy: 0.50310





In [8]:
# Model footprint: bytes occupied by all parameters plus all registered
# buffers (element count times bytes per element), converted to MiB.
param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())

size_all_mb = (param_size + buffer_size) / 1024**2
size_all_mb

327.58924865722656

Next, we also evaluate versions that use 24 and 64 landmarks, respectively.

In [9]:
import sys, importlib

# Reload the module that is actually imported below when it is already loaded
# (i.e. this cell is being re-run), so edits to its source are picked up.
# Previously the reload targeted 'models.modeling_Nystromformer', which is not
# the module this cell imports from.
_variant_module = 'models.modeling_Nystromformer_24landmarks'
if _variant_module in sys.modules:
    importlib.reload(sys.modules[_variant_module])
# NOTE: this rebinds the notebook-level names VisionTransformer and CONFIGS to
# the 24-landmark module's definitions.
from models.modeling_Nystromformer_24landmarks import VisionTransformer, CONFIGS

# Build the 24-landmark model and restore the same checkpoint into it.
config = CONFIGS[args.model_type]
num_classes = 10 if args.dataset == "cifar10" else 100
model24 = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes)
model24.load_state_dict(torch.load(args.pretrained_model, map_location=torch.device(args.device)))
model24.to(args.device)
train_loader, test_loader = get_loader(args)

Files already downloaded and verified
Files already downloaded and verified


In [10]:
# Evaluate the 24-landmark model on the test loader.
test(model24, test_loader)

Validating... (loss=X.X): 100%|| 157/157 [01:00<00:00,  2.59it/s]

Valid Accuracy: 0.38510





In [11]:
import sys, importlib

# Reload the module that is actually imported below when it is already loaded
# (i.e. this cell is being re-run), so edits to its source are picked up.
# Previously the reload targeted 'models.modeling_Nystromformer', which is not
# the module this cell imports from.
_variant_module = 'models.modeling_Nystromformer_64landmarks'
if _variant_module in sys.modules:
    importlib.reload(sys.modules[_variant_module])
# NOTE: this rebinds the notebook-level names VisionTransformer and CONFIGS to
# the 64-landmark module's definitions.
from models.modeling_Nystromformer_64landmarks import VisionTransformer, CONFIGS

# Build the 64-landmark model and restore the same checkpoint into it.
config = CONFIGS[args.model_type]
num_classes = 10 if args.dataset == "cifar10" else 100
model64 = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes)
model64.load_state_dict(torch.load(args.pretrained_model, map_location=torch.device(args.device)))
model64.to(args.device)
train_loader, test_loader = get_loader(args)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
# Evaluate the 64-landmark model on the test loader.
test(model64, test_loader)

Validating... (loss=X.X): 100%|| 157/157 [01:04<00:00,  2.45it/s]

Valid Accuracy: 0.74010





Currently we use 2 iterations of the iterative inverse for each self-attention and the exact coefficient computation (instead of the original implementation's way of computing the coefficient of Z_0). It is also possible to try other values, in combination with different choices for the number of landmarks.