In [1]:
import logging
import argparse
import os
import random
import numpy as np

from datetime import timedelta

import torch
import torch.distributed as dist

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

#from apex import amp
#from apex.parallel import DistributedDataParallel as DDP

from models.modeling import VisionTransformer, CONFIGS
from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
from utils.data_utils import get_loader
from utils.dist_util import get_world_size

import logging

import torch

from torchvision import transforms, datasets
from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler

In [2]:
logger = logging.getLogger(__name__)

In [3]:
def setup(args):
    # Prepare model
    config = CONFIGS[args.model_type]

    num_classes = 10 if args.dataset == "cifar10" else 100

    model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes)
    model.load_from(np.load(args.pretrained_dir))
    model.to(args.device)
    num_params = count_parameters(model)

    logger.info("{}".format(config))
    logger.info("Training parameters %s", args)
    logger.info("Total Parameter: \t%2.1fM" % num_params)
    print(num_params)
    return args, model

In [4]:
config = CONFIGS['ViT-B_16']
num_classes = 10
img_size = 224

model = VisionTransformer(config, img_size, zero_head=True, num_classes=num_classes)

In [5]:
transform_train = transforms.Compose([
    transforms.RandomResizedCrop((img_size, img_size), scale=(0.05, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
transform_test = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [6]:
trainset = datasets.CIFAR10(root="./data",
                            train=True,
                            download=True,
                            transform=transform_train)
testset = datasets.CIFAR10(root="./data",
                           train=False,
                           download=True,
                           transform=transform_test)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
train_sampler = RandomSampler(trainset)# if args.local_rank == -1 else DistributedSampler(trainset)
test_sampler = SequentialSampler(testset)

train_loader = DataLoader(trainset,
                          sampler=train_sampler,
                          batch_size=4,
                          num_workers=4)
test_loader = DataLoader(testset,
                         sampler=test_sampler,
                         batch_size=1,
                         num_workers=4)

In [8]:
device='cpu'

model.to(device)
print('\n')





In [25]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9, weight_decay=0)
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=3, t_total=1000)
loss_fct = torch.nn.CrossEntropyLoss()

In [14]:
output = model(trainset[0][0].unsqueeze(0))

In [21]:
output[0][:, 0].size()

torch.Size([1, 768])

In [22]:
output[1].size()

torch.Size([1, 10])

In [61]:
def simple_accuracy(preds, labels):
    return (preds == labels).mean()

In [100]:
num_epochs=100
for epochs in range(num_epochs):

    model.train()
    #epoch_iterator = tqdm(train_loader,
    #                      desc="Training (X / X Steps) (loss=X.X)",
    #                      bar_format="{l_bar}{r_bar}",
    #                      dynamic_ncols=True,
    #                      disable=args.local_rank not in [-1, 0])
    
    
    train_loss = []
    all_preds, all_label = [], []
    
    for step, batch in enumerate(train_loader):
        
        x = batch[0].to(device)
        y = batch[1].to(device)
        
        output = model(x)[0]
        loss = loss_fct(output, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        
        
        train_loss += [loss.item()]
        preds = torch.argmax(output, -1)
        
        all_preds += preds.tolist()
        all_label += y.tolist()   
        
        acc = simple_accuracy(np.array(all_preds), np.array(all_label))
        
        print(acc)
        
        #if step % 100 == 0:
        #    log.info("Epoch : {}/{}, Batch : {}/{}, train loss : {}, accuracy : {}".format(epochs+1, num_epochs, step+1, len(train_loader),
        #                                                                  sum(train_loss)/len(train_loss), |acc))
        
        
    test_loss = []
    all_preds, all_label = [], []        
        
    for step, batch in enumerate(test_loader):
        
        x = batch[0].to(device)
        y = batch[1].to(device)
        
        output = model(x)[0]
        loss = loss_fct(output, y)
        
        test_loss += [loss.item()]
        preds = torch.argmax(output, -1)
        
        all_preds.append(preds.tolist())
        all_label.append(y.tolist())        
        
        acc = simple_accuracy(all_preds[0], all_label[0])
        
        if step % 100 == 0:
            log.info("Epoch : {}/{}, Batch : {}/{}, train loss : {}, accuracy : {}".format(epochs+1, num_epochs, step+1, len(test_loader),
                                                                          sum(test_loss)/len(test_loss),
                                                                                          acc))

0.5
0.25


KeyboardInterrupt: 