In [3]:
# Environment setup: fetch the ViT-pytorch code, install its extra Python
# packages, and download the checkpoint file that is loaded later on.

# Clone the ViT-pytorch repository and make it the working directory.
!git clone https://github.com/jeonsworld/ViT-pytorch.git
%cd ViT-pytorch/
# Install the two extra packages (versions are not pinned here).
!pip install ml-collections
!pip install einops
# Download a file from Google Drive and save it as cifar10-100_500_checkpoint.bin.
# The inner wget requests the download page, stores the session cookies in
# /tmp/cookies.txt and uses sed to extract the "confirm" token from the response;
# the outer wget reuses those cookies plus the token to fetch the file itself.
# The cookie file is deleted afterwards (only if the download succeeded, due to &&).
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-AxL45qSt354FadCyK375_60ehXZfCJs' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-AxL45qSt354FadCyK375_60ehXZfCJs" -O cifar10-100_500_checkpoint.bin && rm -rf /tmp/cookies.txt

Cloning into 'ViT-pytorch'...
remote: Enumerating objects: 170, done.[K
remote: Counting objects: 100% (40/40), done.[K
remote: Compressing objects: 100% (13/13), done.[K
remote: Total 170 (delta 32), reused 27 (delta 27), pack-reused 130[K
Receiving objects: 100% (170/170), 21.31 MiB | 8.28 MiB/s, done.
Resolving deltas: 100% (85/85), done.
/content/ViT-pytorch
Collecting ml-collections
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
[K     |████████████████████████████████| 77 kB 5.1 MB/s 
Building wheels for collected packages: ml-collections
  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone
  Created wheel for ml-collections: filename=ml_collections-0.1.1-py3-none-any.whl size=94524 sha256=9620ddead2e4c8143ca2c4fb0efc74affdd2753994e1de58ee66cb621dde9f23
  Stored in directory: /root/.cache/pip/wheels/b7/da/64/33c926a1b10ff19791081b705879561b715a8341a856a3bbd2
Successfully built ml-collections
Installing collected packages: ml-collections
Successfully install

In [4]:
from __future__ import absolute_import, division, print_function

import logging
import argparse
import os
import random
import numpy as np

from datetime import timedelta

import torch
import torch.distributed as dist

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from models.modeling import VisionTransformer, CONFIGS
from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
from utils.data_utils import get_loader
from utils.dist_util import get_world_size

In [5]:
# Command-line style configuration for the run. Every option has a default, so
# inside the notebook the parser is fed an empty argument list and `args` simply
# carries the defaults; edit the defaults (or the list passed to parse_args) to
# change a setting.
parser = argparse.ArgumentParser()

# Run identity, model variant and file locations
parser.add_argument("--name", default="cifar10",
                    help="Name of this run. Used for monitoring.")
parser.add_argument("--dataset", choices=["cifar10", "cifar100"], default="cifar10",
                    help="Which downstream task.")
parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16",
                                              "ViT-L_32", "ViT-H_14", "R50-ViT-B_16"],
                    default="ViT-B_16",
                    help="Which variant to use.")
parser.add_argument("--pretrained_dir", type=str, default="checkpoint/ViT-B_16.npz",
                    help="Where to search for pretrained ViT models.")
parser.add_argument("--pretrained_model", type=str, default="cifar10-100_500_checkpoint.bin",
                    help="Path of the checkpoint file whose state dict is loaded into the model.")
parser.add_argument("--output_dir", default="output", type=str,
                    help="The output directory where checkpoints will be written.")

# Input resolution, batch sizes and evaluation cadence
parser.add_argument("--img_size", default=224, type=int,
                    help="Resolution size")
parser.add_argument("--train_batch_size", default=512, type=int,
                    help="Total batch size for training.")
parser.add_argument("--eval_batch_size", default=64, type=int,
                    help="Total batch size for eval.")
parser.add_argument("--eval_every", default=100, type=int,
                    help="Run prediction on validation set every so many steps. "
                         "Will always run one evaluation at the end of training.")

# Optimisation hyper-parameters
parser.add_argument("--learning_rate", default=3e-2, type=float,
                    help="The initial learning rate for SGD.")
parser.add_argument("--weight_decay", default=0, type=float,
                    help="Weight decay if we apply some.")
parser.add_argument("--num_steps", default=10000, type=int,
                    help="Total number of training steps to perform.")
parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
                    help="How to decay the learning rate.")
parser.add_argument("--warmup_steps", default=500, type=int,
                    help="Step of training to perform learning rate warmup for.")
parser.add_argument("--max_grad_norm", default=1.0, type=float,
                    help="Max gradient norm.")

# Distributed execution, reproducibility and mixed precision
parser.add_argument("--local_rank", type=int, default=-1,
                    help="local_rank for distributed training on gpus")
parser.add_argument('--seed', type=int, default=42,
                    help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                    help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--fp16', action='store_true',
                    help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O2',
                    help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
                         "See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument('--loss_scale', type=float, default=0,
                    help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                         "0 (default value): dynamic loss scaling.\n"
                         "Positive power of 2: static loss scaling value.\n")

# Parse an explicit empty argument list: inside a notebook sys.argv belongs to
# the kernel launcher, so it must not be read. ([] states the intent directly;
# the previous "" only worked because argparse converts it to an empty list.)
args = parser.parse_args([])

# Choose the compute device and record it (together with the GPU count) on `args`.
if args.local_rank != -1:
    # Distributed run: bind this process to its own GPU, then join the NCCL
    # process group (60-minute timeout) that keeps the nodes/GPUs synchronized.
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    torch.distributed.init_process_group(
        backend='nccl',
        timeout=timedelta(minutes=60),
    )
    args.n_gpu = 1
else:
    # Single-process run: use CUDA when it is available, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    args.n_gpu = torch.cuda.device_count()

# Expose the selected device to later cells and set the run name explicitly.
args.device = device
args.name = 'cifar10'

In [6]:
# Build the ViT variant selected by args.model_type with a classification head
# sized for the chosen dataset (10 classes for cifar10, otherwise 100).
config = CONFIGS[args.model_type]
num_classes = 10 if args.dataset == "cifar10" else 100
model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes)

# Deserialize the checkpoint onto the CPU first. Without map_location, tensors
# that were saved from a GPU are restored to that GPU, which fails on a
# CPU-only machine even though the device setup above allows a CPU fallback.
# The weights are moved to the target device by model.to() right after.
model.load_state_dict(torch.load(args.pretrained_model, map_location="cpu"))
model.to(args.device)

# Data loaders built from the same args (dataset choice, batch sizes, image size, ...).
train_loader, test_loader = get_loader(args)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [7]:
def simple_accuracy(preds, labels):
    """Return the fraction of positions at which ``preds`` equals ``labels``.

    Both arguments are compared element-wise; the result of that comparison
    must provide ``.mean()`` (e.g. NumPy arrays of equal shape).
    """
    is_correct = preds == labels
    return is_correct.mean()

def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and report its top-1 accuracy.

    The model is switched to eval mode, every batch is run without gradient
    tracking, and the arg-max of the first element returned by ``model(x)``
    is taken as the predicted class. Predictions and labels are gathered
    batch by batch and compared once at the end.

    Relies on the notebook-level ``args`` for the target device
    (``args.device``) and for deciding whether to show the progress bar
    (only when ``args.local_rank`` is -1 or 0).

    Args:
        model: network to evaluate; ``model(x)[0]`` must be the class logits.
        test_loader: iterable yielding ``(inputs, labels)`` tensor pairs.

    Returns:
        The accuracy over all samples, which is also printed as
        ``Valid Accuracy: ...``.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    epoch_iterator = tqdm(test_loader,
                          desc="Validating... (loss=X.X)",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True,
                          disable=args.local_rank not in [-1, 0])

    # Keep one array per batch and join them once after the loop; appending to
    # a growing array inside the loop re-copies everything seen so far.
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for x, y in epoch_iterator:
            # Only the inputs need to live on the model's device; the labels
            # are consumed as NumPy arrays on the host. No loss is computed
            # because its value was never used or displayed.
            logits = model(x.to(args.device))[0]
            preds = torch.argmax(logits, dim=-1)
            pred_chunks.append(preds.cpu().numpy())
            label_chunks.append(y.detach().cpu().numpy())

    if not pred_chunks:
        raise ValueError("test_loader yielded no batches; cannot compute accuracy")

    all_preds = np.concatenate(pred_chunks, axis=0)
    all_label = np.concatenate(label_chunks, axis=0)
    accuracy = simple_accuracy(all_preds, all_label)
    print("Valid Accuracy: %2.5f" % accuracy)
    return accuracy

In [8]:
# Evaluate the model loaded from the checkpoint on the test loader;
# test() prints the resulting validation accuracy.
test(model, test_loader)

Validating... (loss=X.X): 100%|| 157/157 [00:34<00:00,  4.58it/s]

Valid Accuracy: 0.98940





In [9]:
# Clone the DeiT-CIFAR repository into the current working directory
# (the ViT-pytorch checkout at this point, so the clone ends up nested inside it).
!git clone https://github.com/Eurus-Holmes/DeiT-CIFAR.git

Cloning into 'DeiT-CIFAR'...
remote: Enumerating objects: 28, done.[K
remote: Counting objects: 100% (28/28), done.[K
remote: Compressing objects: 100% (26/26), done.[K
remote: Total 28 (delta 8), reused 0 (delta 0), pack-reused 0[K
Unpacking objects: 100% (28/28), done.


In [10]:
# Switch into the DeiT-CIFAR checkout so that the following commands run from it.
%cd DeiT-CIFAR/

/content/ViT-pytorch/DeiT-CIFAR


In [11]:
# Install timm pinned to 0.4.12 (presumably the version the DeiT-CIFAR code expects — confirm).
!pip install timm==0.4.12

Collecting timm==0.4.12
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)
[?25l[K     |▉                               | 10 kB 34.1 MB/s eta 0:00:01[K     |█▊                              | 20 kB 40.6 MB/s eta 0:00:01[K     |██▋                             | 30 kB 23.5 MB/s eta 0:00:01[K     |███▌                            | 40 kB 14.2 MB/s eta 0:00:01[K     |████▍                           | 51 kB 12.6 MB/s eta 0:00:01[K     |█████▏                          | 61 kB 14.8 MB/s eta 0:00:01[K     |██████                          | 71 kB 14.3 MB/s eta 0:00:01[K     |███████                         | 81 kB 13.9 MB/s eta 0:00:01[K     |███████▉                        | 92 kB 15.3 MB/s eta 0:00:01[K     |████████▊                       | 102 kB 15.2 MB/s eta 0:00:01[K     |█████████▋                      | 112 kB 15.2 MB/s eta 0:00:01[K     |██████████▍                     | 122 kB 15.2 MB/s eta 0:00:01[K     |███████████▎                    | 133 kB 15.2 MB/s et

In [15]:
# Run DeiT-CIFAR's main.py with the deit_tiny_patch16_224 model on the CIFAR
# data set: the pretrained checkpoint URL is passed via --finetune, with batch
# size 64, 13 epochs, and results written to the "output" directory.
!python main.py --finetune https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth --data-set CIFAR --model deit_tiny_patch16_224 --batch-size 64 --epochs 13 --output_dir output

Not using distributed mode
Namespace(aa='rand-m9-mstd0.5-inc1', batch_size=64, clip_grad=None, color_jitter=0.4, cooldown_epochs=10, cutmix=1.0, cutmix_minmax=None, data_path='/datasets01/imagenet_full_size/061417/', data_set='CIFAR', decay_epochs=30, decay_rate=0.1, device='cuda', dist_eval=False, dist_url='env://', distillation_alpha=0.5, distillation_tau=1.0, distillation_type='none', distributed=False, drop=0.0, drop_path=0.1, epochs=13, eval=False, finetune='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', inat_category='name', input_size=224, lr=0.0005, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, min_lr=1e-05, mixup=0.8, mixup_mode='batch', mixup_prob=1.0, mixup_switch_prob=0.5, model='deit_tiny_patch16_224', model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, momentum=0.9, num_workers=10, opt='adamw', opt_betas=None, opt_eps=1e-08, output_dir='output', patience_epochs=10, pin_mem=True, recount=1, remode='pixel', repeated_aug=True, 

In [None]:
# Mount Google Drive at /content/drive (Colab only).
# If running on a local device, comment out these lines.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
