## Library

In [1]:
import timm
import torch
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.optim import AdamW, SGD
from torch import nn
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR

from vit_pooling import ViTPooling

# Enumerate the CUDA devices visible to this process; both lists stay empty
# on a CPU-only host.
cuda_available = torch.cuda.is_available()
gpu_ids = list(range(torch.cuda.device_count())) if cuda_available else []
device_names = [torch.cuda.get_device_name(idx) for idx in gpu_ids]
print(gpu_ids)
print(device_names)

# Device string used by the rest of the notebook: the second GPU when more
# than one is present, otherwise the default CUDA device, otherwise the CPU.
if len(gpu_ids) > 1:
    gpu = f'cuda:{gpu_ids[1]}'  # GPU Number
elif cuda_available:
    gpu = "cuda"
else:
    gpu = "cpu"

[0, 1, 2, 3]
['NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090']


## Hyperparameters

In [2]:
# --- Runtime -----------------------------------------------------------------
device = gpu  # device string selected in the GPU-detection cell

# --- Data loading ------------------------------------------------------------
BATCH_SIZE = 64
NUM_WORKERS = 2
NUM_CLASSES = 1000  # not referenced by the other cells shown in this notebook

# --- Optimisation ------------------------------------------------------------
NUM_EPOCHS = 8
LEARNING_RATE = 3e-4

# --- Checkpoints -------------------------------------------------------------
# Checkpoint read when the trainer is started with load=True.
pre_model_path = './save/ViT_timm_vit_base_patch16_224_in21k.pt'
# Checkpoint written during/after fine-tuning; the file name embeds the epoch
# count and learning rate defined above.
fine_model_path = f'./save/ViT_timm_vit_base_patch16_224_in21k_augVanilla_i2012_ep{NUM_EPOCHS}_lr{LEARNING_RATE}.pt'

## Dataset

In [3]:
# Training pipeline: random-resized crop to 224x224 as augmentation, then
# ToTensor + Normalize(0.5, 0.5) to map pixel values into [-1, 1].
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# Evaluation pipeline: deterministic resize + centre crop so that validation
# numbers are repeatable and not perturbed by random augmentation.
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# pre_train_set = torchvision.datasets.ImageFolder('./data/ImageNet-21k', transform=transform_train)
# pre_train_loader = data.DataLoader(pre_train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
train_set = torchvision.datasets.ImageFolder('../../YJ/ILSVRC2012/train', transform=transform_train)
train_loader = data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_set = torchvision.datasets.ImageFolder('../../YJ/ILSVRC2012/val', transform=transform_test)
# Evaluation order does not affect the metrics, so keep it fixed for reproducibility.
test_loader = data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

## Fine-tuning Class

In [4]:
class FineTunner(object):
    """Fine-tune a timm Vision Transformer and checkpoint it periodically.

    The class reads these notebook-level names: ``device``, ``LEARNING_RATE``,
    ``NUM_EPOCHS``, ``train_loader``, ``pre_model_path`` and
    ``fine_model_path``.

    Attributes:
        model: The network being trained (``None`` until ``build_model``).
        optimizer: The optimizer whose state is written into checkpoints.
        epochs: One entry (1-based epoch number) per periodic checkpoint.
        losses: The single-batch loss recorded alongside each ``epochs`` entry.
    """

    def __init__(self):
        self.model = None
        self.optimizer = None
        self.epochs = [0]
        self.losses = [0]

    def process(self, load=False):
        """Build the model, fine-tune it, and write a final checkpoint.

        Args:
            load: When true, weights are restored from ``pre_model_path``
                before fine-tuning starts.
        """
        self.build_model(load)
        self.finetune_model()
        self.save_model()

    def _print_model_summary(self):
        """Print the trainable-parameter count and the model's ``num_classes``."""
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        print(f'Classes: {self.model.num_classes}')

    def build_model(self, load):
        """Create the ViT on ``device`` and optionally restore a checkpoint.

        Args:
            load: When true, read ``pre_model_path`` and load its ``'model'``
                state dict; the stored epoch/loss history is printed and then
                discarded so fine-tuning starts with empty histories.
        """
        # NOTE(review): NUM_CLASSES is not passed to create_model here, so the
        # classifier head keeps the size timm assigns to this variant -
        # confirm that this matches the fine-tuning label set.
        self.model = timm.create_model('vit_base_patch16_224_in21k', pretrained=True).to(device)
        self._print_model_summary()
        # Placeholder so that save_model() works before fine-tuning has run;
        # finetune_model() replaces it with the optimizer actually used.
        self.optimizer = SGD(self.model.parameters(), lr=0)
        if load:
            # map_location moves the stored tensors onto the current device
            # regardless of which device wrote the checkpoint.
            checkpoint = torch.load(pre_model_path, map_location=device)
            self.epochs = checkpoint['epochs']
            self.model.load_state_dict(checkpoint['model'])
            self.losses = checkpoint['losses']
            self._print_model_summary()
            # Guard against a checkpoint whose epoch history is empty.
            last_epoch = self.epochs[-1] if self.epochs else 0
            print(f'Epoch: {last_epoch}')
            print(f'****** Reset epochs and losses ******')
            self.epochs = []
            self.losses = []

    def finetune_model(self):
        """Train ``self.model`` on ``train_loader`` for ``NUM_EPOCHS`` epochs.

        Uses cross-entropy loss and SGD (``lr=LEARNING_RATE``, momentum 0.9).
        The mean loss of every 100 batches is printed, and every 1000 batches
        the current epoch number and that batch's loss are appended to the
        histories and a checkpoint is written via ``save_model``.
        """
        model = self.model
        model.train()  # make sure dropout etc. are in training mode
        criterion = nn.CrossEntropyLoss()
        optimizer = SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
        # Register the real optimizer immediately so that every checkpoint,
        # including the final one written by process(), stores its state.
        self.optimizer = optimizer

        for epoch in range(NUM_EPOCHS):
            running_loss = 0.0
            for i, batch in tqdm(enumerate(train_loader, 0), total=len(train_loader)):
                inputs, labels = batch
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # Read the scalar once; each .item() call synchronises with the device.
                batch_loss = loss.item()
                running_loss += batch_loss
                if i % 100 == 99:
                    print(f'[Epoch {epoch}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0
                if i % 1000 == 999:
                    # The history stores 1-based epoch numbers, whereas the
                    # progress line above prints the 0-based loop index.
                    self.epochs.append(epoch + 1)
                    self.losses.append(batch_loss)
                    self.save_model()
        print('****** Finished Fine-tuning ******')
        self.model = model

    def save_model(self):
        """Write model/optimizer state and the histories to ``fine_model_path``."""
        import os

        checkpoint = {
            'epochs': self.epochs,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'losses': self.losses,
        }
        # Create the target directory up front so that a long run cannot be
        # lost to a missing folder at save time.
        save_dir = os.path.dirname(fine_model_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(checkpoint, fine_model_path)
        # The history is empty when no periodic checkpoint has been recorded yet.
        last_epoch = self.epochs[-1] if self.epochs else 0
        print(f"****** Model checkpoint saved at epochs {last_epoch} ******")

In [None]:
if __name__ == '__main__':
    # Run the full pipeline: build the model, restore weights from
    # pre_model_path (load=True), fine-tune, then save to fine_model_path.
    trainer = FineTunner()
    trainer.process(load=True)

Parameter: 102595923
Classes: 1000
Parameter: 102595923
Classes: 1000
Epoch: 0
****** Reset epochs and losses ******


  0%|          | 100/20019 [01:04<3:06:20,  1.78it/s]

[Epoch 0, Batch   100] loss: 8.389


  1%|          | 200/20019 [02:07<2:51:13,  1.93it/s] 

[Epoch 0, Batch   200] loss: 7.361


  1%|▏         | 300/20019 [03:11<3:35:45,  1.52it/s]

[Epoch 0, Batch   300] loss: 7.183


  2%|▏         | 400/20019 [04:08<3:04:48,  1.77it/s]

[Epoch 0, Batch   400] loss: 7.067


  2%|▏         | 500/20019 [05:11<3:10:29,  1.71it/s] 

[Epoch 0, Batch   500] loss: 6.959


  3%|▎         | 600/20019 [06:08<3:12:26,  1.68it/s]

[Epoch 0, Batch   600] loss: 6.879


  3%|▎         | 700/20019 [07:04<3:13:56,  1.66it/s]

[Epoch 0, Batch   700] loss: 6.786


  4%|▍         | 800/20019 [08:08<6:20:32,  1.19s/it] 

[Epoch 0, Batch   800] loss: 6.673


  4%|▍         | 900/20019 [09:06<3:19:02,  1.60it/s]

[Epoch 0, Batch   900] loss: 6.565


  5%|▍         | 999/20019 [10:09<2:52:17,  1.84it/s] 

[Epoch 0, Batch  1000] loss: 6.398


  5%|▍         | 1000/20019 [10:12<6:41:16,  1.27s/it]

****** Model checkpoint saved at epochs 1 ******


  5%|▌         | 1100/20019 [11:17<3:10:30,  1.66it/s]

[Epoch 0, Batch  1100] loss: 6.183


  6%|▌         | 1200/20019 [12:20<3:36:44,  1.45it/s]

[Epoch 0, Batch  1200] loss: 5.911


  6%|▋         | 1300/20019 [13:24<3:08:33,  1.65it/s]

[Epoch 0, Batch  1300] loss: 5.567


  7%|▋         | 1400/20019 [14:21<2:38:35,  1.96it/s]

[Epoch 0, Batch  1400] loss: 5.130


  7%|▋         | 1500/20019 [15:25<2:43:12,  1.89it/s] 

[Epoch 0, Batch  1500] loss: 4.648


  8%|▊         | 1600/20019 [16:30<2:44:51,  1.86it/s]

[Epoch 0, Batch  1600] loss: 4.128


  8%|▊         | 1700/20019 [17:32<2:49:08,  1.81it/s]

[Epoch 0, Batch  1700] loss: 3.540


  9%|▉         | 1800/20019 [18:36<5:55:29,  1.17s/it] 

[Epoch 0, Batch  1800] loss: 2.998


  9%|▉         | 1900/20019 [19:37<2:51:04,  1.77it/s]

[Epoch 0, Batch  1900] loss: 2.597


 10%|▉         | 1999/20019 [20:36<3:01:12,  1.66it/s]

[Epoch 0, Batch  2000] loss: 2.286


 10%|▉         | 2000/20019 [20:38<5:47:03,  1.16s/it]

****** Model checkpoint saved at epochs 1 ******


 10%|█         | 2100/20019 [21:43<3:06:55,  1.60it/s]

[Epoch 0, Batch  2100] loss: 2.068


 11%|█         | 2200/20019 [22:45<3:02:04,  1.63it/s]

[Epoch 0, Batch  2200] loss: 1.849


 11%|█▏        | 2300/20019 [23:48<3:09:02,  1.56it/s]

[Epoch 0, Batch  2300] loss: 1.763


 12%|█▏        | 2400/20019 [24:50<2:36:07,  1.88it/s]

[Epoch 0, Batch  2400] loss: 1.657


 12%|█▏        | 2500/20019 [25:50<2:26:49,  1.99it/s]

[Epoch 0, Batch  2500] loss: 1.583


 13%|█▎        | 2600/20019 [26:47<2:28:47,  1.95it/s]

[Epoch 0, Batch  2600] loss: 1.569


 13%|█▎        | 2700/20019 [27:45<2:41:43,  1.78it/s]

[Epoch 0, Batch  2700] loss: 1.453


 14%|█▍        | 2800/20019 [28:42<2:29:39,  1.92it/s]

[Epoch 0, Batch  2800] loss: 1.421


 14%|█▍        | 2900/20019 [29:39<3:01:23,  1.57it/s]

[Epoch 0, Batch  2900] loss: 1.398


 15%|█▍        | 2999/20019 [30:36<2:35:43,  1.82it/s]

[Epoch 0, Batch  3000] loss: 1.378


 15%|█▍        | 3000/20019 [30:40<6:58:31,  1.48s/it]

****** Model checkpoint saved at epochs 1 ******


 15%|█▌        | 3100/20019 [31:50<3:31:15,  1.33it/s]

[Epoch 0, Batch  3100] loss: 1.368


 16%|█▌        | 3200/20019 [32:53<2:26:56,  1.91it/s]

[Epoch 0, Batch  3200] loss: 1.329


 16%|█▋        | 3300/20019 [33:54<2:24:31,  1.93it/s]

[Epoch 0, Batch  3300] loss: 1.288


 17%|█▋        | 3400/20019 [34:58<2:55:06,  1.58it/s] 

[Epoch 0, Batch  3400] loss: 1.329


 17%|█▋        | 3500/20019 [36:01<3:00:12,  1.53it/s]

[Epoch 0, Batch  3500] loss: 1.247


 18%|█▊        | 3600/20019 [37:04<2:57:12,  1.54it/s]

[Epoch 0, Batch  3600] loss: 1.251


 18%|█▊        | 3700/20019 [38:04<2:44:16,  1.66it/s]

[Epoch 0, Batch  3700] loss: 1.277


 19%|█▉        | 3800/20019 [39:01<2:48:06,  1.61it/s]

[Epoch 0, Batch  3800] loss: 1.226


 19%|█▉        | 3900/20019 [39:59<2:11:22,  2.04it/s]

[Epoch 0, Batch  3900] loss: 1.200


 20%|█▉        | 3999/20019 [40:56<2:15:11,  1.97it/s]

[Epoch 0, Batch  4000] loss: 1.222


 20%|█▉        | 4000/20019 [40:59<5:54:30,  1.33s/it]

****** Model checkpoint saved at epochs 1 ******


 20%|██        | 4100/20019 [42:05<2:46:53,  1.59it/s] 

[Epoch 0, Batch  4100] loss: 1.178


 21%|██        | 4200/20019 [43:03<2:47:50,  1.57it/s]

[Epoch 0, Batch  4200] loss: 1.191


 21%|██▏       | 4300/20019 [44:00<2:31:55,  1.72it/s]

[Epoch 0, Batch  4300] loss: 1.163


 22%|██▏       | 4400/20019 [44:58<2:35:58,  1.67it/s]

[Epoch 0, Batch  4400] loss: 1.200


 22%|██▏       | 4500/20019 [45:55<2:40:06,  1.62it/s]

[Epoch 0, Batch  4500] loss: 1.182


 23%|██▎       | 4600/20019 [46:54<2:40:05,  1.61it/s]

[Epoch 0, Batch  4600] loss: 1.176


 23%|██▎       | 4700/20019 [47:55<3:13:39,  1.32it/s]

[Epoch 0, Batch  4700] loss: 1.103


 24%|██▍       | 4800/20019 [49:04<2:47:56,  1.51it/s]

[Epoch 0, Batch  4800] loss: 1.102


 24%|██▍       | 4900/20019 [50:13<3:08:36,  1.34it/s]

[Epoch 0, Batch  4900] loss: 1.124


 25%|██▍       | 4999/20019 [51:19<2:31:46,  1.65it/s]

[Epoch 0, Batch  5000] loss: 1.107


 25%|██▍       | 5000/20019 [51:23<6:07:35,  1.47s/it]

****** Model checkpoint saved at epochs 1 ******


 25%|██▌       | 5100/20019 [52:43<2:26:25,  1.70it/s] 

[Epoch 0, Batch  5100] loss: 1.131


 26%|██▌       | 5200/20019 [53:55<3:08:15,  1.31it/s]

[Epoch 0, Batch  5200] loss: 1.162


 26%|██▋       | 5300/20019 [55:08<2:54:38,  1.40it/s] 

[Epoch 0, Batch  5300] loss: 1.113


 27%|██▋       | 5400/20019 [56:14<2:35:17,  1.57it/s]

[Epoch 0, Batch  5400] loss: 1.110


 27%|██▋       | 5500/20019 [57:18<2:18:17,  1.75it/s]

[Epoch 0, Batch  5500] loss: 1.141


 28%|██▊       | 5600/20019 [58:16<2:23:17,  1.68it/s]

[Epoch 0, Batch  5600] loss: 1.127


 28%|██▊       | 5700/20019 [59:21<2:18:51,  1.72it/s]

[Epoch 0, Batch  5700] loss: 1.068


 29%|██▉       | 5800/20019 [1:00:31<2:14:00,  1.77it/s] 

[Epoch 0, Batch  5800] loss: 1.092


 29%|██▉       | 5900/20019 [1:01:36<2:19:09,  1.69it/s]

[Epoch 0, Batch  5900] loss: 1.090


 30%|██▉       | 5999/20019 [1:02:41<2:41:59,  1.44it/s]

[Epoch 0, Batch  6000] loss: 1.104


 30%|██▉       | 6000/20019 [1:02:44<5:04:11,  1.30s/it]

****** Model checkpoint saved at epochs 1 ******


 30%|███       | 6100/20019 [1:03:57<2:34:46,  1.50it/s]

[Epoch 0, Batch  6100] loss: 1.087


 31%|███       | 6200/20019 [1:05:00<2:14:12,  1.72it/s]

[Epoch 0, Batch  6200] loss: 1.062


 31%|███▏      | 6300/20019 [1:06:03<2:32:50,  1.50it/s]

[Epoch 0, Batch  6300] loss: 1.065


 32%|███▏      | 6400/20019 [1:07:07<2:28:32,  1.53it/s]

[Epoch 0, Batch  6400] loss: 1.086


 32%|███▏      | 6500/20019 [1:08:14<2:09:44,  1.74it/s]

[Epoch 0, Batch  6500] loss: 1.078


 33%|███▎      | 6600/20019 [1:09:19<2:18:59,  1.61it/s]

[Epoch 0, Batch  6600] loss: 1.086


 33%|███▎      | 6700/20019 [1:10:22<2:15:49,  1.63it/s]

[Epoch 0, Batch  6700] loss: 1.069


 34%|███▍      | 6800/20019 [1:11:25<2:14:16,  1.64it/s]

[Epoch 0, Batch  6800] loss: 1.078


 34%|███▍      | 6900/20019 [1:12:41<2:20:24,  1.56it/s] 

[Epoch 0, Batch  6900] loss: 1.088


 35%|███▍      | 6999/20019 [1:13:47<2:08:05,  1.69it/s]

[Epoch 0, Batch  7000] loss: 1.049


 35%|███▍      | 7000/20019 [1:13:50<4:51:41,  1.34s/it]

****** Model checkpoint saved at epochs 1 ******


 35%|███▌      | 7100/20019 [1:15:04<2:51:59,  1.25it/s] 

[Epoch 0, Batch  7100] loss: 1.002


 36%|███▌      | 7200/20019 [1:16:08<2:11:32,  1.62it/s]

[Epoch 0, Batch  7200] loss: 1.060


 36%|███▋      | 7300/20019 [1:17:18<2:08:51,  1.65it/s] 

[Epoch 0, Batch  7300] loss: 1.076


 37%|███▋      | 7400/20019 [1:18:26<2:11:29,  1.60it/s]

[Epoch 0, Batch  7400] loss: 1.077


 37%|███▋      | 7500/20019 [1:19:34<1:56:20,  1.79it/s] 

[Epoch 0, Batch  7500] loss: 1.042


 38%|███▊      | 7592/20019 [1:20:32<2:00:27,  1.72it/s]