# timm ViT 구현 Cifar-100 fine-tuning

## timm 설치

pip install timm

## Library

In [None]:
import timm
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm.notebook import tqdm

gpu_ids = []
device_names = []
if torch.cuda.is_available():
    for gpu_id in range(torch.cuda.device_count()):
        gpu_ids += [gpu_id]
        device_names += [torch.cuda.get_device_name(gpu_id)]
print(gpu_ids)
print(device_names)

if len(gpu_ids) > 1:
    gpu = 'cuda:' + str(gpu_ids[3])  # GPU Number
else:
    gpu = "cuda" if torch.cuda.is_available() else "cpu"

## Hyper parameter

In [2]:
model_path = './save/timm_ViT_Cifar100.pt'
device = gpu
BATCH_SIZE = 32
NUM_EPOCHS = 500
NUM_WORKERS = 2
LEARNING_RATE = 0.01
NUM_CLASSES = 100

## Dataset

In [3]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = datasets.CIFAR100(root='./data/', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
testset = datasets.CIFAR100(root='./data/', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

Files already downloaded and verified
Files already downloaded and verified


## Class 선언

In [4]:
class ViTCifar100Model(object):
    def __init__(self):
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.epochs = []
        self.losses = []

    def process(self):
        self.build_model()
        self.train_model()
        self.eval_model()

    def build_model(self):
        self.model = timm.models.vit_base_patch16_224(pretrained=True).to(device)
        self.model.num_classes = NUM_CLASSES
        # self.model = timm.models.vit_large_patch16_224(pretrained=True).to(device)
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        print(f'Classes: {self.model.num_classes}')

    def train_model(self):
        model = self.model.to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
        scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0, last_epoch=-1)

        for epoch in range(NUM_EPOCHS):
            running_loss = 0.0
            for i, data in tqdm(enumerate(trainloader, 0), total=len(trainloader)):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
#                 if i % 10 == 0:
#                     print(f'[Epoch {epoch + 1}, Batch {i + 1:5d}] loss: {loss / 100:.3f}')
            if epoch % 1 == 0:
                self.epochs.append(epoch + 1)
                self.model = model
                self.optimizer = optimizer
                self.scheduler = scheduler
                self.losses.append(running_loss)
                self.save_model()
            scheduler.step()
        print('****** Finished Training ******')

    def save_model(self):
        checkpoint = {
            'epochs': self.epochs,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'losses': self.losses,
        }
        torch.save(checkpoint, model_path)
        print(f"****** Model checkpoint saved at epoch {self.epochs[-1]} ******")

    def eval_model(self):
        checkpoint = torch.load(model_path)
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.model.eval()

        correct = 0
        total = 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = self.model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        print(f'Accuracy {len(testset)} test images: {100 * correct / total:.2f} %')

## Method 실행

In [None]:
if __name__ == '__main__':
    vit = ViTCifar100Model()
    vit.process()

Parameter: 86567656
Classes: 100


  0%|          | 0/1563 [00:00<?, ?it/s]

****** Model checkpoint saved at epoch 1 ******


  0%|          | 0/1563 [00:00<?, ?it/s]

## Loss check

In [None]:
print(vit.losses)