# Tester

In [1]:
import timm
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.utils.data as data
import torchvision
from tqdm import tqdm
import torch.nn.functional as F
import math

from vit_pooling import ViTPooling


gpu_ids = []
device_names = []
if torch.cuda.is_available():
    for gpu_id in range(torch.cuda.device_count()):
        gpu_ids += [gpu_id]
        device_names += [torch.cuda.get_device_name(gpu_id)]
print(gpu_ids)
print(device_names)

if len(gpu_ids) > 1:
    gpu = 'cuda:' + str(gpu_ids[2])  # GPU Number
else:
    gpu = "cuda" if torch.cuda.is_available() else "cpu"

[0, 1, 2, 3]
['NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090']


In [2]:
device = gpu
model_path = './save/ViT_timm_vit_base_patch16_224_in21k_augNegative_i2012_ep8_lr0.0003.pt'
TIMM_MODEL = 'vit_base_patch16_224_in21k'
BATCH_SIZE = 512
NUM_EPOCHS = 300
NUM_WORKERS = 2
LEARNING_RATE = 0.003

IMAGE_SIZE = 224
PATCH_SIZE = 16
IN_CHANNELS = 3
NUM_CLASSES = 1000
EMBED_DIM = 768
DEPTH = 12
NUM_HEADS = 12

In [3]:
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
transform_test = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# pre_train_set = torchvision.datasets.ImageFolder('./data/ImageNet-21k', transform=transform_train)
# pre_train_loader = data.DataLoader(pre_train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
train_set = torchvision.datasets.ImageFolder('../../YJ/ILSVRC2012/train', transform=transform_train)
train_loader = data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_set = torchvision.datasets.ImageFolder('../../YJ/ILSVRC2012/val', transform=transform_test)
test_loader = data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

In [4]:
class TesterTimm(object):
    def __init__(self):
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.epochs = []
        self.losses = []

    def process(self):
        self.build_model()
        self.eval_model()

    def loss_checker(self):
        self.build_model()
        print(f'Steps: {len(self.epochs)}k steps')
        [print(f'Sampling Loss: {i:.3f}') for i in self.losses]
        
    def build_model(self):
        self.model = timm.create_model(TIMM_MODEL, pretrained=True).to(device)
        self.model.num_classes = NUM_CLASSES
        checkpoint = torch.load(model_path)
        self.epochs = checkpoint['epochs']
        self.model.load_state_dict(checkpoint['model'])
        self.losses = checkpoint['losses']
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        print(f'Classes: {self.model.num_classes}')
        print(f'Epochs: {self.epochs[-1]}')
        
    def eval_model(self):
        self.model.eval()

        correct = 0
        total = 0
        with torch.no_grad():
            for i, data in tqdm(enumerate(test_loader, 0), total=len(test_loader)):
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = self.model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        print(f'Accuracy of {len(test_set)} test images: {100 * correct / total:.2f} %')

In [5]:
class TesterPaper(object):
    def __init__(self):
        self.model = None
        self.optimizer = None
        self.epochs = []
        self.losses = []
        self.cls_token = None

    def process(self):
        self.build_model()
        self.eval_model()
        
    def loss_checker(self):
        self.build_model()
        print(f'Steps: {len(self.epochs)}k steps')
        [print(f'Sampling Loss: {i:.3f}') for i in self.losses]

    def build_model(self):
        self.model = ViTPooling(image_size=IMAGE_SIZE,
                         patch_size=PATCH_SIZE,
                         in_channels=IN_CHANNELS,
                         num_classes=NUM_CLASSES,
                         embed_dim=EMBED_DIM,
                         depth=DEPTH,
                         num_heads=NUM_HEADS,
                         ).to(device)
        checkpoint = torch.load(model_path)
        self.epochs = checkpoint['epochs']
        self.model.load_state_dict(checkpoint['model'])
        self.losses = checkpoint['losses']
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        print(f'Classes: {self.model.mlp_head.num_classes}')
        print(f'Epochs: {self.epochs[-1]}')

    def eval_model(self):
        self.model.eval()

        correct = 0
        total = 0
        with torch.no_grad():
            for i, data in tqdm(enumerate(test_loader, 0), total=len(test_loader)):
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = self.model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        print(f'Accuracy of {len(test_set)} test images: {100 * correct / total:.2f} %')

In [6]:
if __name__ == '__main__':
    t = TesterTimm()
    t.loss_checker()

Parameter: 102595923
Classes: 1000
Epochs: 8
Steps: 160k steps
Sampling Loss: 2.613
Sampling Loss: 1.651
Sampling Loss: 1.380
Sampling Loss: 1.317
Sampling Loss: 1.484
Sampling Loss: 1.449
Sampling Loss: 1.230
Sampling Loss: 1.124
Sampling Loss: 0.851
Sampling Loss: 0.963
Sampling Loss: 1.351
Sampling Loss: 1.463
Sampling Loss: 1.462
Sampling Loss: 1.623
Sampling Loss: 1.057
Sampling Loss: 0.879
Sampling Loss: 1.083
Sampling Loss: 0.868
Sampling Loss: 0.924
Sampling Loss: 1.283
Sampling Loss: 0.779
Sampling Loss: 0.848
Sampling Loss: 1.110
Sampling Loss: 0.620
Sampling Loss: 1.022
Sampling Loss: 1.047
Sampling Loss: 0.864
Sampling Loss: 1.426
Sampling Loss: 0.997
Sampling Loss: 0.946
Sampling Loss: 1.143
Sampling Loss: 1.027
Sampling Loss: 0.718
Sampling Loss: 1.075
Sampling Loss: 1.150
Sampling Loss: 0.656
Sampling Loss: 0.805
Sampling Loss: 1.292
Sampling Loss: 0.642
Sampling Loss: 0.795
Sampling Loss: 1.037
Sampling Loss: 1.116
Sampling Loss: 0.771
Sampling Loss: 0.981
Sampling Loss

In [7]:
if __name__ == '__main__':
#     t = TesterTimm()
#     t.process()

    t = TesterTimm()
    for i in range(12):
        global model_path
        model_path = f'./save/ViT_timm_vit_base_patch16_224_in21k_augNegative_i2012_ep{i+1}_lr0.0003.pt'
        t.process()

Parameter: 102595923
Classes: 1000
Epochs: 1


100%|██████████| 98/98 [03:19<00:00,  1.67s/it]


Accuracy of 50000 test images: 73.38 %
Parameter: 102595923
Classes: 1000
Epochs: 2


100%|██████████| 98/98 [03:18<00:00,  1.72s/it]


Accuracy of 50000 test images: 74.65 %
Parameter: 102595923
Classes: 1000
Epochs: 3


100%|██████████| 98/98 [03:18<00:00,  1.56s/it]


Accuracy of 50000 test images: 75.85 %
Parameter: 102595923
Classes: 1000
Epochs: 4


100%|██████████| 98/98 [03:19<00:00,  1.76s/it]


Accuracy of 50000 test images: 76.29 %
Parameter: 102595923
Classes: 1000
Epochs: 5


  5%|▌         | 5/98 [00:13<04:24,  2.85s/it]

KeyboardInterrupt: 

In [None]:
# 프리 트레이닝
# 1스텝 = 9분
# 1에포크 = 3시간

# 파인 튜닝
# 1스텝 = 6분
# 1에포크 = 2시간

# baseline 이름 안적은거 수정해야함.