## Library

In [1]:
import timm
import torch
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.optim import AdamW, SGD
from torch import nn
from tqdm import tqdm, tqdm_notebook
from torch.optim.lr_scheduler import CosineAnnealingLR

from vit_pooling import ViTPooling
from patch_aug import NegativePatchShuffle, NegativePatchRotate, PositivePatchShuffle, PositivePatchRotate


gpu_ids = []
device_names = []
if torch.cuda.is_available():
    for gpu_id in range(torch.cuda.device_count()):
        gpu_ids += [gpu_id]
        device_names += [torch.cuda.get_device_name(gpu_id)]
print(gpu_ids)
print(device_names)

if len(gpu_ids) > 1:
    gpu = 'cuda:' + str(gpu_ids[0])  # GPU Number
else:
    gpu = "cuda" if torch.cuda.is_available() else "cpu"

[0, 1, 2, 3]
['NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090']


## Hyper parameter

In [2]:
device = gpu
BATCH_SIZE = 64
NUM_EPOCHS = 8
NUM_WORKERS = 2
LEARNING_RATE = 1.25e-03
pre_model_path = './save/ViT/timm/ViT_timm_vit_base_patch16_224_in21k.pt'
fine_model_path = f'./save/ViT_timm_vit_base_patch16_224_in21k_augVanilla_i2012_ep{NUM_EPOCHS}_lr{LEARNING_RATE}.pt'
dynamic_model_path = f'./save/ViT_timm_vit_base_patch16_224_in21k_augVanilla_i2012_ep'

NUM_CLASSES = 1000

## Dataset

In [3]:
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
transform_test = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# pre_train_set = torchvision.datasets.ImageFolder('./data/ImageNet-21k', transform=transform_train)
# pre_train_loader = data.DataLoader(pre_train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
train_set = torchvision.datasets.ImageFolder('../../YJ/ILSVRC2012/train', transform=transform_train)
train_loader = data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_set = torchvision.datasets.ImageFolder('../../YJ/ILSVRC2012/val', transform=transform_test)
test_loader = data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

## Fine-tuning Class

In [4]:
class FineTunner(object):
    def __init__(self):
        self.model = None
        self.optimizer = None
        self.epochs = [0]
        self.losses = [0]

    def process(self, load=False):
        self.build_model(load)
        self.finetune_model()
        self.save_model()

    def build_model(self, load):
        self.model = timm.create_model('vit_base_patch16_224_in21k', pretrained=True).to(device)
        self.model.num_classes = NUM_CLASSES
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        print(f'Classes: {self.model.num_classes}')
        self.optimizer = SGD(self.model.parameters(), lr=0)
        if load:
            checkpoint = torch.load(pre_model_path)
            self.epochs = checkpoint['epochs']
            self.model.load_state_dict(checkpoint['model'])
            self.losses = checkpoint['losses']
            print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
            print(f'Classes: {self.model.num_classes}')
            print(f'Epoch: {self.epochs[-1]}')
            print(f'****** Reset epochs and losses ******')
            self.epochs = []
            self.losses = []

    def finetune_model(self):
        model = self.model
        criterion = nn.CrossEntropyLoss()
        optimizer = SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
        scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

        for epoch in range(NUM_EPOCHS):
            running_loss = 0.0
            saving_loss = 0.0
            for i, data in tqdm_notebook(enumerate(train_loader, 0), total=len(train_loader)):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                saving_loss += loss.item()
                if i % 100 == 99:
                    print(f'[Epoch {epoch}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0
                if i % 1000 == 999:
                    self.epochs.append(epoch + 1)
                    self.model = model
                    self.optimizer = optimizer
                    self.losses.append(saving_loss/1000)
                    self.save_model()
                    saving_loss = 0.0
            scheduler.step()
        print('****** Finished Fine-tuning ******')
        self.model = model

    def save_model(self):
        checkpoint = {
            'epochs': self.epochs,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'losses': self.losses,
        }
        torch.save(checkpoint, fine_model_path)
#         torch.save(checkpoint, dynamic_model_path+str(self.epochs[-1])+f'_lr{LEARNING_RATE}.pt')
        print(f"****** Model checkpoint saved at epochs {self.epochs[-1]} ******")

In [5]:
if __name__ == '__main__':
    trainer = FineTunner()
    trainer.process(load=True)

Parameter: 102595923
Classes: 1000
Parameter: 102595923
Classes: 1000
Epoch: 0
****** Reset epochs and losses ******


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 0, Batch   100] loss: 7.956
[Epoch 0, Batch   200] loss: 7.176
[Epoch 0, Batch   300] loss: 6.822
[Epoch 0, Batch   400] loss: 6.367
[Epoch 0, Batch   500] loss: 5.320
[Epoch 0, Batch   600] loss: 3.924
[Epoch 0, Batch   700] loss: 2.889
[Epoch 0, Batch   800] loss: 2.269
[Epoch 0, Batch   900] loss: 2.055
[Epoch 0, Batch  1000] loss: 2.069
****** Model checkpoint saved at epochs 1 ******
[Epoch 0, Batch  1100] loss: 2.008
[Epoch 0, Batch  1200] loss: 1.774
[Epoch 0, Batch  1300] loss: 1.745
[Epoch 0, Batch  1400] loss: 1.711
[Epoch 0, Batch  1500] loss: 1.670
[Epoch 0, Batch  1600] loss: 1.637
[Epoch 0, Batch  1700] loss: 1.617
[Epoch 0, Batch  1800] loss: 1.596
[Epoch 0, Batch  1900] loss: 1.578
[Epoch 0, Batch  2000] loss: 1.529
****** Model checkpoint saved at epochs 1 ******
[Epoch 0, Batch  2100] loss: 1.534
[Epoch 0, Batch  2200] loss: 1.484
[Epoch 0, Batch  2300] loss: 1.495
[Epoch 0, Batch  2400] loss: 1.460
[Epoch 0, Batch  2500] loss: 1.403
[Epoch 0, Batch  2600] loss

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 1, Batch   100] loss: 0.977
[Epoch 1, Batch   200] loss: 0.975
[Epoch 1, Batch   300] loss: 1.019
[Epoch 1, Batch   400] loss: 0.950
[Epoch 1, Batch   500] loss: 0.993
[Epoch 1, Batch   600] loss: 0.993
[Epoch 1, Batch   700] loss: 0.965
[Epoch 1, Batch   800] loss: 0.959
[Epoch 1, Batch   900] loss: 0.966
[Epoch 1, Batch  1000] loss: 1.022
****** Model checkpoint saved at epochs 2 ******
[Epoch 1, Batch  1100] loss: 0.989
[Epoch 1, Batch  1200] loss: 1.031
[Epoch 1, Batch  1300] loss: 1.002
[Epoch 1, Batch  1400] loss: 0.950
[Epoch 1, Batch  1500] loss: 1.006
[Epoch 1, Batch  1600] loss: 0.989
[Epoch 1, Batch  1700] loss: 1.031
[Epoch 1, Batch  1800] loss: 1.016
[Epoch 1, Batch  1900] loss: 1.008
[Epoch 1, Batch  2000] loss: 0.990
****** Model checkpoint saved at epochs 2 ******
[Epoch 1, Batch  2100] loss: 1.016
[Epoch 1, Batch  2200] loss: 0.955
[Epoch 1, Batch  2300] loss: 0.917
[Epoch 1, Batch  2400] loss: 0.973
[Epoch 1, Batch  2500] loss: 0.940
[Epoch 1, Batch  2600] loss

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 2, Batch   100] loss: 0.839
[Epoch 2, Batch   200] loss: 0.858
[Epoch 2, Batch   300] loss: 0.873
[Epoch 2, Batch   400] loss: 0.841
[Epoch 2, Batch   500] loss: 0.862
[Epoch 2, Batch   600] loss: 0.822
[Epoch 2, Batch   700] loss: 0.862
[Epoch 2, Batch   800] loss: 0.815
[Epoch 2, Batch   900] loss: 0.823
[Epoch 2, Batch  1000] loss: 0.825
****** Model checkpoint saved at epochs 3 ******
[Epoch 2, Batch  1100] loss: 0.839
[Epoch 2, Batch  1200] loss: 0.853
[Epoch 2, Batch  1300] loss: 0.846
[Epoch 2, Batch  1400] loss: 0.819
[Epoch 2, Batch  1500] loss: 0.838
[Epoch 2, Batch  1600] loss: 0.845
[Epoch 2, Batch  1700] loss: 0.830
[Epoch 2, Batch  1800] loss: 0.905
[Epoch 2, Batch  1900] loss: 0.889
[Epoch 2, Batch  2000] loss: 0.856
****** Model checkpoint saved at epochs 3 ******
[Epoch 2, Batch  2100] loss: 0.872
[Epoch 2, Batch  2200] loss: 0.824
[Epoch 2, Batch  2300] loss: 0.825
[Epoch 2, Batch  2400] loss: 0.829
[Epoch 2, Batch  2500] loss: 0.823
[Epoch 2, Batch  2600] loss

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 3, Batch   100] loss: 0.733
[Epoch 3, Batch   200] loss: 0.737
[Epoch 3, Batch   300] loss: 0.761
[Epoch 3, Batch   400] loss: 0.707
[Epoch 3, Batch   500] loss: 0.715
[Epoch 3, Batch   600] loss: 0.748
[Epoch 3, Batch   700] loss: 0.754
[Epoch 3, Batch   800] loss: 0.762
[Epoch 3, Batch   900] loss: 0.746
[Epoch 3, Batch  1000] loss: 0.763
****** Model checkpoint saved at epochs 4 ******
[Epoch 3, Batch  1100] loss: 0.742
[Epoch 3, Batch  1200] loss: 0.759
[Epoch 3, Batch  1300] loss: 0.729
[Epoch 3, Batch  1400] loss: 0.725
[Epoch 3, Batch  1500] loss: 0.750
[Epoch 3, Batch  1600] loss: 0.739
[Epoch 3, Batch  1700] loss: 0.747
[Epoch 3, Batch  1800] loss: 0.740
[Epoch 3, Batch  1900] loss: 0.734
[Epoch 3, Batch  2000] loss: 0.730
****** Model checkpoint saved at epochs 4 ******
[Epoch 3, Batch  2100] loss: 0.753
[Epoch 3, Batch  2200] loss: 0.753
[Epoch 3, Batch  2300] loss: 0.729
[Epoch 3, Batch  2400] loss: 0.700
[Epoch 3, Batch  2500] loss: 0.750
[Epoch 3, Batch  2600] loss

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 4, Batch   100] loss: 0.663
[Epoch 4, Batch   200] loss: 0.637
[Epoch 4, Batch   300] loss: 0.679
[Epoch 4, Batch   400] loss: 0.673
[Epoch 4, Batch   500] loss: 0.633
[Epoch 4, Batch   600] loss: 0.660
[Epoch 4, Batch   700] loss: 0.639
[Epoch 4, Batch   800] loss: 0.651
[Epoch 4, Batch   900] loss: 0.644
[Epoch 4, Batch  1000] loss: 0.649
****** Model checkpoint saved at epochs 5 ******
[Epoch 4, Batch  1100] loss: 0.636
[Epoch 4, Batch  1200] loss: 0.642
[Epoch 4, Batch  1300] loss: 0.683
[Epoch 4, Batch  1400] loss: 0.628
[Epoch 4, Batch  1500] loss: 0.633
[Epoch 4, Batch  1600] loss: 0.651
[Epoch 4, Batch  1700] loss: 0.673
[Epoch 4, Batch  1800] loss: 0.668
[Epoch 4, Batch  1900] loss: 0.649
[Epoch 4, Batch  2000] loss: 0.652
****** Model checkpoint saved at epochs 5 ******
[Epoch 4, Batch  2100] loss: 0.632
[Epoch 4, Batch  2200] loss: 0.632
[Epoch 4, Batch  2300] loss: 0.610
[Epoch 4, Batch  2400] loss: 0.636
[Epoch 4, Batch  2500] loss: 0.669
[Epoch 4, Batch  2600] loss

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 5, Batch   100] loss: 0.565
[Epoch 5, Batch   200] loss: 0.606
[Epoch 5, Batch   300] loss: 0.585
[Epoch 5, Batch   400] loss: 0.564
[Epoch 5, Batch   500] loss: 0.568
[Epoch 5, Batch   600] loss: 0.605
[Epoch 5, Batch   700] loss: 0.592
[Epoch 5, Batch   800] loss: 0.588
[Epoch 5, Batch   900] loss: 0.558
[Epoch 5, Batch  1000] loss: 0.551
****** Model checkpoint saved at epochs 6 ******
[Epoch 5, Batch  1100] loss: 0.589
[Epoch 5, Batch  1200] loss: 0.586
[Epoch 5, Batch  1300] loss: 0.553
[Epoch 5, Batch  1400] loss: 0.561
[Epoch 5, Batch  1500] loss: 0.557
[Epoch 5, Batch  1600] loss: 0.537
[Epoch 5, Batch  1700] loss: 0.591
[Epoch 5, Batch  1800] loss: 0.589
[Epoch 5, Batch  1900] loss: 0.561
[Epoch 5, Batch  2000] loss: 0.570
****** Model checkpoint saved at epochs 6 ******
[Epoch 5, Batch  2100] loss: 0.570
[Epoch 5, Batch  2200] loss: 0.579
[Epoch 5, Batch  2300] loss: 0.577
[Epoch 5, Batch  2400] loss: 0.583
[Epoch 5, Batch  2500] loss: 0.563
[Epoch 5, Batch  2600] loss

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 6, Batch   100] loss: 0.505
[Epoch 6, Batch   200] loss: 0.527
[Epoch 6, Batch   300] loss: 0.513
[Epoch 6, Batch   400] loss: 0.543
[Epoch 6, Batch   500] loss: 0.525
[Epoch 6, Batch   600] loss: 0.498
[Epoch 6, Batch   700] loss: 0.505
[Epoch 6, Batch   800] loss: 0.490
[Epoch 6, Batch   900] loss: 0.516
[Epoch 6, Batch  1000] loss: 0.508
****** Model checkpoint saved at epochs 7 ******
[Epoch 6, Batch  1100] loss: 0.507
[Epoch 6, Batch  1200] loss: 0.545
[Epoch 6, Batch  1300] loss: 0.513
[Epoch 6, Batch  1400] loss: 0.506
[Epoch 6, Batch  1500] loss: 0.547
[Epoch 6, Batch  1600] loss: 0.527
[Epoch 6, Batch  1700] loss: 0.529
[Epoch 6, Batch  1800] loss: 0.518
[Epoch 6, Batch  1900] loss: 0.494
[Epoch 6, Batch  2000] loss: 0.513
****** Model checkpoint saved at epochs 7 ******
[Epoch 6, Batch  2100] loss: 0.526
[Epoch 6, Batch  2200] loss: 0.503
[Epoch 6, Batch  2300] loss: 0.481
[Epoch 6, Batch  2400] loss: 0.494
[Epoch 6, Batch  2500] loss: 0.501
[Epoch 6, Batch  2600] loss

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 7, Batch   100] loss: 0.486
[Epoch 7, Batch   200] loss: 0.479
[Epoch 7, Batch   300] loss: 0.469
[Epoch 7, Batch   400] loss: 0.483
[Epoch 7, Batch   500] loss: 0.465
[Epoch 7, Batch   600] loss: 0.487
[Epoch 7, Batch   700] loss: 0.495
[Epoch 7, Batch   800] loss: 0.486
[Epoch 7, Batch   900] loss: 0.457
[Epoch 7, Batch  1000] loss: 0.453
****** Model checkpoint saved at epochs 8 ******
[Epoch 7, Batch  1100] loss: 0.472
[Epoch 7, Batch  1200] loss: 0.453
[Epoch 7, Batch  1300] loss: 0.446
[Epoch 7, Batch  1400] loss: 0.459
[Epoch 7, Batch  1500] loss: 0.464
[Epoch 7, Batch  1600] loss: 0.446
[Epoch 7, Batch  1700] loss: 0.458
[Epoch 7, Batch  1800] loss: 0.480
[Epoch 7, Batch  1900] loss: 0.463
[Epoch 7, Batch  2000] loss: 0.489
****** Model checkpoint saved at epochs 8 ******
[Epoch 7, Batch  2100] loss: 0.485
[Epoch 7, Batch  2200] loss: 0.491
[Epoch 7, Batch  2300] loss: 0.445
[Epoch 7, Batch  2400] loss: 0.482
[Epoch 7, Batch  2500] loss: 0.477
[Epoch 7, Batch  2600] loss