In [1]:
from model.pytorch_pretrained_vit import ViT
from regularizer import prune_vit
import dataloader

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from copy import deepcopy

from utils import AverageMeter, get_weights_copy

import time

import torch.nn as nn
import torch.optim as optim

import torch
import os

In [2]:
# Root directory for all checkpoints written/read by this notebook.
# Override with the WORK_RESULTS_DIR environment variable to run on another machine;
# the default is the original workstation path.
BASE_PATH = os.environ.get('WORK_RESULTS_DIR', '/workspace/paper_works/work_results/')

In [3]:
def train_iter(net, loader, criterion, optimizer, epoch, device, reg_ratio=0., print_freq=200):
    """Run one training epoch and return a snapshot of the resulting weights.

    Args:
        net: model to train; put into train mode here.
        loader: iterable of ``(X, y)`` batches; ``len(loader)`` is used for logging.
        criterion: loss function, called as ``criterion(pred, y)``.
        optimizer: optimizer over ``net``'s parameters.
        epoch: current epoch index (logged, and stored as ``epoch + 1``).
        device: device each batch is moved to.
        reg_ratio: when > 0, ``prune_vit(net, reg_ratio)`` is called after every
            optimizer step and once more after the epoch.
        print_freq: log every ``print_freq`` batches and on the last batch.

    Returns:
        dict with ``'epoch'``, ``'state_dict'`` (from ``get_weights_copy``) and
        ``'loss'`` (average training loss as a Python float). When
        ``reg_ratio > 0`` it also holds ``'reg_idxs'`` / ``'reg_lams'``, the
        values returned by the final ``prune_vit`` call.
    """
    losses = AverageMeter()
    net.train()
    num_batches = len(loader)
    for i, (X, y) in enumerate(loader):
        X, y = X.to(device), y.to(device)

        pred = net(X)
        loss = criterion(pred, y)

        # detach() (rather than the legacy .data) so the meter never holds
        # anything tied to the autograd graph.
        losses.update(loss.detach(), X.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if reg_ratio > 0.:
            # Re-apply pruning after each update; presumably this keeps pruned
            # weights from being revived by the optimizer (see prune_vit).
            # The return values are only needed from the final call below.
            prune_vit(net, reg_ratio)

        if i % print_freq == 0 or i == num_batches - 1:
            print(' Epoch: [{0}][{1}/{2}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch, i, num_batches, loss=losses))

    if reg_ratio > 0.:
        idxs, lams = prune_vit(net, reg_ratio)

    state = {
        'epoch': epoch + 1,
        'state_dict': get_weights_copy(net, device),
        # float() works whether the meter's average is a 0-d tensor or a plain
        # number (e.g. when the loader was empty).
        'loss': float(losses.avg),
    }

    if reg_ratio > 0.:
        state['reg_idxs'] = idxs
        state['reg_lams'] = lams

    return state

def valid_iter(net, loader, criterion, optimizer, epoch, device, print_freq=100):
    """Evaluate ``net`` on ``loader`` and return the average loss.

    Args:
        net: model to evaluate; put into eval mode here.
        loader: iterable of ``(X, y)`` batches; ``len(loader)`` is used for logging.
        criterion: loss function, called as ``criterion(pred, y)``.
        optimizer: unused; kept so the signature mirrors ``train_iter``.
        epoch: current epoch index, used only for logging.
        device: device each batch is moved to.
        print_freq: log every ``print_freq`` batches and on the last batch.

    Returns:
        ``{'loss': average validation loss as a Python float}``.
    """
    losses = AverageMeter()
    net.eval()
    num_batches = len(loader)
    # Evaluation needs no gradients: skip building the autograd graph to save
    # memory and time.
    with torch.no_grad():
        for i, (X, y) in enumerate(loader):
            X, y = X.to(device), y.to(device)

            pred = net(X)
            loss = criterion(pred, y)

            losses.update(loss.detach(), X.size(0))

            if i % print_freq == 0 or i == num_batches - 1:
                print(' Epoch: [{0}][{1}/{2}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                          epoch, i, num_batches, loss=losses))

    # float() works whether the average is a 0-d tensor or a plain number.
    return {
        'loss': float(losses.avg),
    }

def main(device, net_name, dataset_args, dataset_name, early_stopping_step,
         lr, model_args, model_name, EPOCH=100, reg_ratio=0.):
    """Train a ViT on ``dataset_name`` and save the best-validation-loss state.

    With ``reg_ratio == 0`` the model starts from pretrained weights and the
    result is saved as ``best_state.ptl``. With ``reg_ratio != 0`` the model is
    initialised from that ``best_state.ptl`` checkpoint, trained with
    ``train_iter(..., reg_ratio)`` (which calls ``prune_vit``), and saved as
    ``best_state_ratio<percent>.ptl``. Files live in
    ``BASE_PATH/<dataset_name>_<net_name>/``.

    Args:
        device: torch device string, e.g. ``'cuda:5'``.
        net_name: name used for the results sub-directory.
        dataset_args: kwargs forwarded to ``dataloader.<dataset_name>``.
        dataset_name: name of the loader function in the ``dataloader`` module.
        early_stopping_step: stop once more than this many consecutive epochs
            pass without a validation-loss improvement.
        lr: Adam learning rate.
        model_args: kwargs forwarded to ``ViT``.
        model_name: ViT variant name forwarded to ``ViT``.
        EPOCH: maximum number of epochs.
        reg_ratio: pruning ratio; 0 means plain fine-tuning.
    """
    result_dir = os.path.join(BASE_PATH, '{}_{}'.format(dataset_name, net_name))

    if reg_ratio == 0:
        net = ViT(model_name, True, **model_args)
        file_name = 'best_state.ptl'
    else:
        # Start from the unpruned checkpoint produced by a reg_ratio == 0 run.
        net = ViT(model_name, False, **model_args)
        state = torch.load(os.path.join(result_dir, 'best_state.ptl'), map_location=device)
        net.load_state_dict(state['state_dict'])
        # round(), not int(): float error makes e.g. int(0.57 * 100) == 56.
        file_name = 'best_state_ratio%d.ptl' % round(reg_ratio * 100)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=lr)

    trainset, testset, num_classes = getattr(dataloader, dataset_name)(**dataset_args)

    best_state = None
    best_loss = float('inf')
    early_stopping_idx = 0
    for epoch in range(EPOCH):
        early_stopping_idx += 1
        train_state = train_iter(net, trainset, criterion, optimizer, epoch, device, reg_ratio)
        valid_state = valid_iter(net, testset, criterion, optimizer, epoch, device)

        if best_loss > valid_state['loss']:
            best_loss = valid_state['loss']
            best_state = train_state
            early_stopping_idx = 0

        if early_stopping_idx > early_stopping_step:
            print('early stopping.')
            break

    if best_state is None:
        # No epoch improved (EPOCH == 0, or the validation loss was NaN every
        # epoch): don't overwrite an existing checkpoint with None.
        print('no improving epoch; nothing saved.')
        del net
        return

    # makedirs also creates BASE_PATH itself if it is missing.
    os.makedirs(result_dir, exist_ok=True)
    torch.save(best_state, os.path.join(result_dir, file_name))

    # Drop the local reference to the model before the next run in a sweep.
    del net

In [None]:
args = {
    'device' : 'cuda:5',
    'net_name' : 'vit_B_16_imagenet1k',
    'dataset_name' : 'cifar100',
    'dataset_args': {
        'batch_size' : 128,
    },
    'early_stopping_step' : 7,
    'lr' : 0.0005,
    'model_args' : {
        'num_classes' : 100,
        'image_size' : 32,
        'in_channels' : 3,
    },
    'model_name' : 'B_16_imagenet1k',
}

# Pruning sweep: for reg_ratio != 0, main() loads this dataset's best_state.ptl first.
for ratio in [0.1, 0.3, 0.5, 0.7, 0.9]:
    args['reg_ratio'] = ratio
    main(**args)

Using augmented CIFAR 100.
Files already downloaded and verified
 Epoch: [0][0/391]	Loss 1.1289 (1.1289)	
 Epoch: [0][200/391]	Loss 1.6259 (1.4537)	
 Epoch: [0][390/391]	Loss 1.4008 (1.4395)	
 Epoch: [0][0/79]	Loss 2.1181 (2.1181)	
 Epoch: [0][78/79]	Loss 1.2113 (1.8565)	
 Epoch: [1][0/391]	Loss 0.9945 (0.9945)	
 Epoch: [1][200/391]	Loss 1.1876 (1.2497)	
 Epoch: [1][390/391]	Loss 0.9846 (1.2805)	
 Epoch: [1][0/79]	Loss 1.6766 (1.6766)	
 Epoch: [1][78/79]	Loss 2.1892 (1.8709)	
 Epoch: [2][0/391]	Loss 1.1250 (1.1250)	
 Epoch: [2][200/391]	Loss 1.0516 (1.1148)	
 Epoch: [2][390/391]	Loss 1.0814 (1.1508)	
 Epoch: [2][0/79]	Loss 1.9108 (1.9108)	
 Epoch: [2][78/79]	Loss 1.1043 (1.8455)	
 Epoch: [3][0/391]	Loss 0.9673 (0.9673)	
 Epoch: [3][200/391]	Loss 1.1658 (1.0100)	
 Epoch: [3][390/391]	Loss 1.1948 (1.0442)	
 Epoch: [3][0/79]	Loss 1.9645 (1.9645)	
 Epoch: [3][78/79]	Loss 1.7141 (1.8941)	
 Epoch: [4][0/391]	Loss 0.8583 (0.8583)	
 Epoch: [4][200/391]	Loss 0.9809 (0.9258)	
 Epoch: [4][390/391

 Epoch: [16][200/391]	Loss 0.8422 (0.6801)	


```{python}
device = 'cuda:5'
net_name = 'vit'
dataset_name = 'cifar10'
batch_size = 128
early_stopping_step = 5
lr = 0.0005

model_args = {
    'num_classes' : 10,
    'image_size' : 32,
    'in_channels' : 3,
}
model_name = 'B_16_imagenet1k'

net = ViT(model_name, True, **model_args)
net.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=lr)

# Loader functions live in the `dataloader` module (`cifar10` is not imported by name);
# look it up from dataset_name the same way main() does.
trainset, testset, num_classes = getattr(dataloader, dataset_name)(batch_size=batch_size)
```

## Train ViT model on CIFAR-10 (no pruning)
```{python}

# Manual version of the loop in main(): keep the state from the epoch with the
# lowest validation loss and stop after more than `early_stopping_step`
# consecutive epochs without improvement.
max_epochs = 100
best_state = None
best_loss = float('inf')
early_stopping_idx = 0
for epoch in range(max_epochs):
    early_stopping_idx += 1
    train_state = train_iter(net, trainset, criterion, optimizer, epoch, device)
    valid_state = valid_iter(net, testset, criterion, optimizer, epoch, device)

    if best_loss > valid_state['loss']:
        best_loss = valid_state['loss']
        best_state = train_state
        early_stopping_idx = 0

    if early_stopping_idx > early_stopping_step:
        print('early stopping.')
        break

save_dir = os.path.join(BASE_PATH, '{}_{}'.format(dataset_name, net_name))
# makedirs also creates BASE_PATH itself if it is missing.
os.makedirs(save_dir, exist_ok=True)

torch.save(best_state, os.path.join(save_dir, 'best_state.ptl'))
```

In [5]:
# Datasets covered in this notebook: fmnist, mnist, imagenet, cifar10, cifar100

In [7]:
args = dict(
    device='cuda:5',
    net_name='vit_B_16_imagenet1k',
    dataset_name='fmnist',
    dataset_args=dict(batch_size=128),
    early_stopping_step=7,
    lr=0.0005,
    model_args=dict(num_classes=10, image_size=28, in_channels=1),
    model_name='B_16_imagenet1k',
)

# Baseline run (reg_ratio defaults to 0): fine-tune from pretrained weights and
# write best_state.ptl, which the pruning runs load.
main(**args)

Resized positional embeddings from torch.Size([1, 577, 768]) to torch.Size([1, 2, 768])
Loaded pretrained weights.
 Epoch: [0][0/469]	Loss 2.3026 (2.3026)	
 Epoch: [0][200/469]	Loss 0.8916 (1.1547)	
 Epoch: [0][400/469]	Loss 0.7908 (0.9405)	
 Epoch: [0][0/79]	Loss 0.5761 (0.5761)	
 Epoch: [1][0/469]	Loss 0.8311 (0.8311)	
 Epoch: [1][200/469]	Loss 0.6708 (0.6294)	
 Epoch: [1][400/469]	Loss 0.7531 (0.6185)	
 Epoch: [1][0/79]	Loss 0.5679 (0.5679)	
 Epoch: [2][0/469]	Loss 0.5744 (0.5744)	
 Epoch: [2][200/469]	Loss 0.5523 (0.5655)	
 Epoch: [2][400/469]	Loss 0.6629 (0.5607)	
 Epoch: [2][0/79]	Loss 0.5592 (0.5592)	
 Epoch: [3][0/469]	Loss 0.3930 (0.3930)	
 Epoch: [3][200/469]	Loss 0.6584 (0.5280)	
 Epoch: [3][400/469]	Loss 0.4988 (0.5261)	
 Epoch: [3][0/79]	Loss 0.6163 (0.6163)	
 Epoch: [4][0/469]	Loss 0.5425 (0.5425)	
 Epoch: [4][200/469]	Loss 0.6367 (0.5121)	
 Epoch: [4][400/469]	Loss 0.4724 (0.5075)	
 Epoch: [4][0/79]	Loss 0.4423 (0.4423)	
 Epoch: [5][0/469]	Loss 0.4794 (0.4794)	
 Epoch: [

In [None]:
args = dict(
    device='cuda:5',
    net_name='vit_B_16_imagenet1k',
    dataset_name='fmnist',
    dataset_args=dict(batch_size=128),
    early_stopping_step=7,
    lr=0.0005,
    model_args=dict(num_classes=10, image_size=28, in_channels=1),
    model_name='B_16_imagenet1k',
)

# Pruning sweep over increasing ratios; each run starts from best_state.ptl.
for ratio in (0.1, 0.3, 0.5, 0.7, 0.9):
    args.update(reg_ratio=ratio)
    main(**args)

In [8]:
args = dict(
    device='cuda:5',
    net_name='vit_B_16_imagenet1k',
    dataset_name='mnist',
    dataset_args=dict(batch_size=128),
    early_stopping_step=7,
    lr=0.0005,
    model_args=dict(num_classes=10, image_size=28, in_channels=1),
    model_name='B_16_imagenet1k',
)

# Baseline run (reg_ratio defaults to 0): fine-tune from pretrained weights and
# write best_state.ptl, which the pruning runs load.
main(**args)

Resized positional embeddings from torch.Size([1, 577, 768]) to torch.Size([1, 2, 768])
Loaded pretrained weights.
 Epoch: [0][0/469]	Loss 2.3026 (2.3026)	
 Epoch: [0][200/469]	Loss 0.6219 (1.0719)	
 Epoch: [0][400/469]	Loss 0.5050 (0.8242)	
 Epoch: [0][0/79]	Loss 0.3814 (0.3814)	
 Epoch: [1][0/469]	Loss 0.4314 (0.4314)	
 Epoch: [1][200/469]	Loss 0.4443 (0.4555)	
 Epoch: [1][400/469]	Loss 0.5112 (0.4329)	
 Epoch: [1][0/79]	Loss 0.3165 (0.3165)	
 Epoch: [2][0/469]	Loss 0.2767 (0.2767)	
 Epoch: [2][200/469]	Loss 0.3829 (0.3612)	
 Epoch: [2][400/469]	Loss 0.3456 (0.3600)	
 Epoch: [2][0/79]	Loss 0.4342 (0.4342)	
 Epoch: [3][0/469]	Loss 0.2976 (0.2976)	
 Epoch: [3][200/469]	Loss 0.2563 (0.3286)	
 Epoch: [3][400/469]	Loss 0.3472 (0.3260)	
 Epoch: [3][0/79]	Loss 0.3588 (0.3588)	
 Epoch: [4][0/469]	Loss 0.4007 (0.4007)	
 Epoch: [4][200/469]	Loss 0.2141 (0.2958)	
 Epoch: [4][400/469]	Loss 0.3258 (0.2963)	
 Epoch: [4][0/79]	Loss 0.2515 (0.2515)	
 Epoch: [5][0/469]	Loss 0.1945 (0.1945)	
 Epoch: [

In [None]:
args = dict(
    device='cuda:5',
    net_name='vit_B_16_imagenet1k',
    dataset_name='mnist',
    dataset_args=dict(batch_size=128),
    early_stopping_step=7,
    lr=0.0005,
    model_args=dict(num_classes=10, image_size=28, in_channels=1),
    model_name='B_16_imagenet1k',
)

# Pruning sweep over increasing ratios; each run starts from best_state.ptl.
for ratio in (0.1, 0.3, 0.5, 0.7, 0.9):
    args.update(reg_ratio=ratio)
    main(**args)

In [5]:
args = dict(
    device='cuda:5',
    net_name='vit_B_16_imagenet1k',
    dataset_name='imagenet',
    # `classes` is forwarded to the imagenet loader; presumably it selects a
    # 10-class subset to match num_classes below — confirm in dataloader.
    dataset_args=dict(batch_size=16, classes=10),
    early_stopping_step=7,
    lr=0.0005,
    model_args=dict(num_classes=10, image_size=224, in_channels=3),
    model_name='B_16_imagenet1k',
)

# Baseline run (reg_ratio defaults to 0): fine-tune from pretrained weights and
# write best_state.ptl, which the pruning runs load.
main(**args)

Resized positional embeddings from torch.Size([1, 577, 768]) to torch.Size([1, 197, 768])
Loaded pretrained weights.
Using augmented IMAGENET.
 Epoch: [0][0/813]	Loss 2.3026 (2.3026)	
 Epoch: [0][200/813]	Loss 1.1444 (1.3426)	
 Epoch: [0][400/813]	Loss 1.0616 (1.2252)	
 Epoch: [0][600/813]	Loss 0.8956 (1.1705)	
 Epoch: [0][800/813]	Loss 1.6084 (1.1251)	
 Epoch: [0][0/32]	Loss 1.3876 (1.3876)	
 Epoch: [1][0/813]	Loss 0.4881 (0.4881)	
 Epoch: [1][200/813]	Loss 0.6999 (0.9505)	
 Epoch: [1][400/813]	Loss 0.9489 (0.9373)	
 Epoch: [1][600/813]	Loss 1.0469 (0.9334)	
 Epoch: [1][800/813]	Loss 2.1118 (0.9313)	
 Epoch: [1][0/32]	Loss 0.6431 (0.6431)	
 Epoch: [2][0/813]	Loss 1.0754 (1.0754)	
 Epoch: [2][200/813]	Loss 0.8215 (0.8467)	
 Epoch: [2][400/813]	Loss 0.6404 (0.8831)	
 Epoch: [2][600/813]	Loss 0.6288 (0.8815)	
 Epoch: [2][800/813]	Loss 1.2450 (0.8779)	
 Epoch: [2][0/32]	Loss 0.5252 (0.5252)	
 Epoch: [3][0/813]	Loss 0.9270 (0.9270)	
 Epoch: [3][200/813]	Loss 0.4926 (0.8811)	
 Epoch: [3][40

In [None]:
args = dict(
    device='cuda:5',
    net_name='vit_B_16_imagenet1k',
    dataset_name='imagenet',
    # `classes` is forwarded to the imagenet loader; presumably it selects a
    # 10-class subset to match num_classes below — confirm in dataloader.
    dataset_args=dict(batch_size=16, classes=10),
    early_stopping_step=7,
    lr=0.0005,
    model_args=dict(num_classes=10, image_size=224, in_channels=3),
    model_name='B_16_imagenet1k',
)

# Pruning sweep over increasing ratios; each run starts from best_state.ptl.
for ratio in (0.1, 0.3, 0.5, 0.7, 0.9):
    args.update(reg_ratio=ratio)
    main(**args)

In [None]:
args = dict(
    device='cuda:5',
    net_name='vit_B_16_imagenet1k',
    dataset_name='cifar10',
    dataset_args=dict(batch_size=128),
    early_stopping_step=7,
    lr=0.0005,
    model_args=dict(num_classes=10, image_size=32, in_channels=3),
    model_name='B_16_imagenet1k',
)

# Pruning sweep over increasing ratios; each run loads
# <dataset>_<net_name>/best_state.ptl, so net_name must match the baseline run.
for ratio in (0.1, 0.3, 0.5, 0.7, 0.9):
    args.update(reg_ratio=ratio)
    main(**args)