# timm ViT 구현 Cifar-100 fine-tuning

## timm 설치

pip install timm

## Library

In [1]:
import timm
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm

gpu_ids = []
device_names = []
if torch.cuda.is_available():
    for gpu_id in range(torch.cuda.device_count()):
        gpu_ids += [gpu_id]
        device_names += [torch.cuda.get_device_name(gpu_id)]
print(gpu_ids)
print(device_names)

if len(gpu_ids) > 1:
    gpu = 'cuda:' + str(gpu_ids[3])  # GPU Number
else:
    gpu = "cuda" if torch.cuda.is_available() else "cpu"

[0, 1, 2, 3]
['TITAN Xp', 'TITAN Xp', 'TITAN Xp', 'TITAN Xp']


## Hyper parameter

In [2]:
model_path = './save/timm_ViT_Cifar100.pt'
device = gpu
BATCH_SIZE = 32
NUM_EPOCHS = 10000
NUM_WORKERS = 2
LEARNING_RATE = 0.001

## Dataset

In [3]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = datasets.CIFAR100(root='./data/', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
testset = datasets.CIFAR100(root='./data/', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

Files already downloaded and verified
Files already downloaded and verified


## Class 선언 및 실행

In [None]:
class ViTCifar100Model(object):
    def __init__(self):
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.epoch = 0

    def process(self):
        self.build_model()
        self.train_model()
        self.eval_model()

    def build_model(self):
        self.model = timm.models.vit_base_patch16_224(pretrained=True).to(device)
        # self.model = timm.models.vit_large_patch16_224(pretrained=True).to(device)
        print(f'Parameter : {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')

    def train_model(self):
        model = self.model.to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
        scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0, last_epoch=-1)

        for epoch in range(NUM_EPOCHS):
            running_loss = 0.0
            for i, data in tqdm(enumerate(trainloader, 0), total=len(trainloader)):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                if i % 100 == 0:
                    print(f'[Epoch {epoch + 1}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0
            if epoch % 10 == 0:
                self.epoch = epoch + 1
                self.model = model
                self.optimizer = optimizer
                self.scheduler = scheduler
                self.save_model()
            scheduler.step()
        print('****** Finished Training ******')

    def save_model(self):
        checkpoint = {
            'epoch': self.epoch,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
        torch.save(checkpoint, model_path)
        print(f"****** Model checkpoint saved at epoch {self.epoch} ******")

    def eval_model(self):
        checkpoint = torch.load(model_path)
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.model.eval()

        correct = 0
        total = 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = self.model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        print(f'Accuracy {len(testset)} test images: {100 * correct / total:.2f} %')
        

if __name__ == '__main__':
    ViTCifar100Model().process()

Parameter : 86567656


  0%|                                          | 1/1563 [00:00<15:54,  1.64it/s]

[Epoch 1, Batch     1] loss: 0.113


  6%|██▌                                     | 101/1563 [00:46<11:24,  2.14it/s]

[Epoch 1, Batch   101] loss: 5.109


 13%|█████▏                                  | 201/1563 [01:33<10:39,  2.13it/s]

[Epoch 1, Batch   201] loss: 4.544


 19%|███████▋                                | 301/1563 [02:20<09:56,  2.12it/s]

[Epoch 1, Batch   301] loss: 4.404


 26%|██████████▎                             | 401/1563 [03:08<09:08,  2.12it/s]

[Epoch 1, Batch   401] loss: 4.296


 32%|████████████▊                           | 501/1563 [03:55<08:21,  2.12it/s]

[Epoch 1, Batch   501] loss: 4.179


 38%|███████████████▍                        | 601/1563 [04:42<07:34,  2.12it/s]

[Epoch 1, Batch   601] loss: 4.055


 45%|█████████████████▉                      | 701/1563 [05:29<06:48,  2.11it/s]

[Epoch 1, Batch   701] loss: 4.067


 51%|████████████████████▍                   | 801/1563 [06:16<05:59,  2.12it/s]

[Epoch 1, Batch   801] loss: 3.978


 58%|███████████████████████                 | 901/1563 [07:04<05:12,  2.12it/s]

[Epoch 1, Batch   901] loss: 3.969


 64%|████████████████████████▉              | 1001/1563 [07:51<04:25,  2.12it/s]

[Epoch 1, Batch  1001] loss: 4.021


 70%|███████████████████████████▍           | 1101/1563 [08:38<03:38,  2.11it/s]

[Epoch 1, Batch  1101] loss: 4.030


 77%|█████████████████████████████▉         | 1201/1563 [09:25<02:51,  2.12it/s]

[Epoch 1, Batch  1201] loss: 3.976


 83%|████████████████████████████████▍      | 1301/1563 [10:13<02:03,  2.11it/s]

[Epoch 1, Batch  1301] loss: 3.928


 90%|██████████████████████████████████▉    | 1401/1563 [11:00<01:16,  2.12it/s]

[Epoch 1, Batch  1401] loss: 3.920


 96%|█████████████████████████████████████▍ | 1501/1563 [11:47<00:29,  2.10it/s]

[Epoch 1, Batch  1501] loss: 3.968


100%|███████████████████████████████████████| 1563/1563 [12:16<00:00,  2.12it/s]


****** Model checkpoint saved at epoch 1 ******


  0%|                                          | 1/1563 [00:00<14:38,  1.78it/s]

[Epoch 2, Batch     1] loss: 0.038


  6%|██▌                                     | 101/1563 [00:47<11:31,  2.11it/s]

[Epoch 2, Batch   101] loss: 3.882


 13%|█████▏                                  | 201/1563 [01:34<10:44,  2.11it/s]

[Epoch 2, Batch   201] loss: 3.892


 19%|███████▋                                | 301/1563 [02:22<09:56,  2.12it/s]

[Epoch 2, Batch   301] loss: 3.889


 26%|██████████▎                             | 401/1563 [03:09<09:09,  2.12it/s]

[Epoch 2, Batch   401] loss: 3.903


 32%|████████████▊                           | 501/1563 [03:56<08:21,  2.12it/s]

[Epoch 2, Batch   501] loss: 3.941


 38%|███████████████▍                        | 601/1563 [04:44<07:34,  2.11it/s]

[Epoch 2, Batch   601] loss: 3.870


 45%|█████████████████▉                      | 701/1563 [05:31<06:47,  2.12it/s]

[Epoch 2, Batch   701] loss: 3.868


 51%|████████████████████▍                   | 801/1563 [06:18<06:00,  2.11it/s]

[Epoch 2, Batch   801] loss: 3.868


 58%|███████████████████████                 | 901/1563 [07:05<05:13,  2.11it/s]

[Epoch 2, Batch   901] loss: 3.886


 64%|████████████████████████▉              | 1001/1563 [07:53<04:25,  2.12it/s]

[Epoch 2, Batch  1001] loss: 3.833


 70%|███████████████████████████▍           | 1101/1563 [08:40<03:38,  2.11it/s]

[Epoch 2, Batch  1101] loss: 3.862


 77%|█████████████████████████████▉         | 1201/1563 [09:27<02:51,  2.11it/s]

[Epoch 2, Batch  1201] loss: 3.870


 83%|████████████████████████████████▍      | 1301/1563 [10:14<02:04,  2.11it/s]

[Epoch 2, Batch  1301] loss: 3.820


 90%|██████████████████████████████████▉    | 1401/1563 [11:02<01:16,  2.12it/s]

[Epoch 2, Batch  1401] loss: 3.855


 96%|█████████████████████████████████████▍ | 1501/1563 [11:49<00:29,  2.11it/s]

[Epoch 2, Batch  1501] loss: 3.802


100%|███████████████████████████████████████| 1563/1563 [12:18<00:00,  2.12it/s]
  0%|                                          | 1/1563 [00:00<14:48,  1.76it/s]

[Epoch 3, Batch     1] loss: 0.040


  6%|██▌                                     | 101/1563 [00:47<11:30,  2.12it/s]

[Epoch 3, Batch   101] loss: 3.811


 13%|█████▏                                  | 201/1563 [01:35<10:44,  2.11it/s]

[Epoch 3, Batch   201] loss: 3.789


 19%|███████▋                                | 301/1563 [02:22<09:56,  2.12it/s]

[Epoch 3, Batch   301] loss: 3.839


 26%|██████████▎                             | 401/1563 [03:09<09:09,  2.11it/s]

[Epoch 3, Batch   401] loss: 3.830


 32%|████████████▊                           | 501/1563 [03:56<08:22,  2.11it/s]

[Epoch 3, Batch   501] loss: 3.823


 38%|███████████████▍                        | 601/1563 [04:44<07:34,  2.12it/s]

[Epoch 3, Batch   601] loss: 3.801


 45%|█████████████████▉                      | 701/1563 [05:31<06:48,  2.11it/s]

[Epoch 3, Batch   701] loss: 3.794


 51%|████████████████████▍                   | 801/1563 [06:18<06:00,  2.11it/s]

[Epoch 3, Batch   801] loss: 3.843


 58%|███████████████████████                 | 901/1563 [07:06<05:13,  2.11it/s]

[Epoch 3, Batch   901] loss: 3.835


 64%|████████████████████████▉              | 1001/1563 [07:53<04:26,  2.11it/s]

[Epoch 3, Batch  1001] loss: 3.797


 70%|███████████████████████████▍           | 1101/1563 [08:40<03:38,  2.11it/s]

[Epoch 3, Batch  1101] loss: 3.804


 77%|█████████████████████████████▉         | 1201/1563 [09:27<02:51,  2.11it/s]

[Epoch 3, Batch  1201] loss: 3.787


 83%|████████████████████████████████▍      | 1301/1563 [10:15<02:03,  2.12it/s]

[Epoch 3, Batch  1301] loss: 3.780


 90%|██████████████████████████████████▉    | 1401/1563 [11:02<01:16,  2.11it/s]

[Epoch 3, Batch  1401] loss: 3.765


 96%|█████████████████████████████████████▍ | 1501/1563 [11:49<00:29,  2.11it/s]

[Epoch 3, Batch  1501] loss: 3.785


100%|███████████████████████████████████████| 1563/1563 [12:18<00:00,  2.12it/s]
  0%|                                          | 1/1563 [00:00<14:43,  1.77it/s]

[Epoch 4, Batch     1] loss: 0.038


  6%|██▌                                     | 101/1563 [00:47<11:32,  2.11it/s]

[Epoch 4, Batch   101] loss: 3.790


 13%|█████▏                                  | 201/1563 [01:35<10:44,  2.11it/s]

[Epoch 4, Batch   201] loss: 3.766


 19%|███████▋                                | 301/1563 [02:22<09:56,  2.12it/s]

[Epoch 4, Batch   301] loss: 3.812


 26%|██████████▎                             | 401/1563 [03:09<09:10,  2.11it/s]

[Epoch 4, Batch   401] loss: 3.774


 32%|████████████▊                           | 501/1563 [03:57<08:23,  2.11it/s]

[Epoch 4, Batch   501] loss: 3.740


 38%|███████████████▍                        | 601/1563 [04:44<07:35,  2.11it/s]

[Epoch 4, Batch   601] loss: 3.780


 45%|█████████████████▉                      | 701/1563 [05:31<06:48,  2.11it/s]

[Epoch 4, Batch   701] loss: 3.730


 51%|████████████████████▍                   | 801/1563 [06:19<06:02,  2.10it/s]

[Epoch 4, Batch   801] loss: 3.754


 58%|███████████████████████                 | 901/1563 [07:06<05:13,  2.11it/s]

[Epoch 4, Batch   901] loss: 3.782


 64%|████████████████████████▉              | 1001/1563 [07:54<04:26,  2.11it/s]

[Epoch 4, Batch  1001] loss: 3.766


 70%|███████████████████████████▍           | 1101/1563 [08:41<03:38,  2.11it/s]

[Epoch 4, Batch  1101] loss: 3.749


 73%|████████████████████████████▍          | 1140/1563 [08:59<03:19,  2.12it/s]