## Library

In [1]:
import timm
import torch
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.optim import AdamW, SGD
from torch import nn
from tqdm import tqdm, tqdm_notebook
from torch.optim.lr_scheduler import CosineAnnealingLR

from patch_aug import NegativePatchShuffle, NegativePatchRotate

gpu_ids = []
device_names = []
if torch.cuda.is_available():
    for gpu_id in range(torch.cuda.device_count()):
        gpu_ids += [gpu_id]
        device_names += [torch.cuda.get_device_name(gpu_id)]
print(gpu_ids)
print(device_names)

if len(gpu_ids) > 1:
    gpu = 'cuda:' + str(gpu_ids[1])  # GPU Number
else:
    gpu = "cuda" if torch.cuda.is_available() else "cpu"

[0, 1, 2, 3]
['NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090']


## Hyper parameter

In [2]:
device = gpu
BATCH_SIZE = 64
NUM_EPOCHS = 8
NUM_WORKERS = 2
LEARNING_RATE = 1.25e-03
pre_model_path = './save/ViT/timm/ViT_timm_vit_base_patch16_224_in21k.pt'
fine_model_path = f'./save/ViT_timm_vit_base_patch16_224_in21k_augNegative_i2012_ep{NUM_EPOCHS}_lr{LEARNING_RATE}.pt'
dynamic_model_path = f'./save/ViT_timm_vit_base_patch16_224_in21k_augNegative_i2012_ep'

NUM_CLASSES = 1000

## Dataset

In [3]:
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
transform_test = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# pre_train_set = torchvision.datasets.ImageFolder('./data/ImageNet-21k', transform=transform_train)
# pre_train_loader = data.DataLoader(pre_train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
train_set = torchvision.datasets.ImageFolder('../../YJ/ILSVRC2012/train', transform=transform_train)
train_loader = data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_set = torchvision.datasets.ImageFolder('../../YJ/ILSVRC2012/val', transform=transform_test)
test_loader = data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

## Fine-tuning Class

In [4]:
class FineTunner(object):
    def __init__(self):
        self.model = None
        self.optimizer = None
        self.epochs = [0]
        self.losses = [0]

    def process(self, load=False):
        self.build_model(load)
        self.finetune_model()
        self.save_model()

    def build_model(self, load):
        self.model = timm.create_model('vit_base_patch16_224_in21k', pretrained=True).to(device)
        self.model.num_classes = NUM_CLASSES
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        print(f'Classes: {self.model.num_classes}')
        self.optimizer = SGD(self.model.parameters(), lr=0)
        if load:
            checkpoint = torch.load(pre_model_path)
            self.epochs = checkpoint['epochs']
            self.model.load_state_dict(checkpoint['model'])
            self.losses = checkpoint['losses']
            print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
            print(f'Classes: {self.model.num_classes}')
            print(f'Epoch: {self.epochs[-1]}')
            print(f'****** Reset epochs and losses ******')
            self.epochs = []
            self.losses = []

    def finetune_model(self):
        model = self.model
        criterion = nn.CrossEntropyLoss()
        optimizer = SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
        scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
        aug = NegativePatchRotate(p=0.5)

        for epoch in range(NUM_EPOCHS):
            running_loss = 0.0
            saving_loss = 0.0
            for i, data in tqdm_notebook(enumerate(train_loader, 0), total=len(train_loader)):
                inputs, labels = data
                aug.roll_the_dice(len(inputs))
                inputs = aug.rotate(inputs)
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = aug.cal_loss(outputs, labels, criterion, device)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                saving_loss += loss.item()
                if i % 100 == 99:
                    print(f'[Epoch {epoch}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0
                if i % 1000 == 999:
                    self.epochs.append(epoch + 1)
                    self.model = model
                    self.optimizer = optimizer
                    self.losses.append(saving_loss/1000)
                    self.save_model()
                    saving_loss = 0.0
            scheduler.step()
        print('****** Finished Fine-tuning ******')
        self.model = model

    def save_model(self):
        checkpoint = {
            'epochs': self.epochs,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'losses': self.losses,
        }
        torch.save(checkpoint, fine_model_path)
        # torch.save(checkpoint, dynamic_model_path+str(self.epochs[-1])+f'_lr{LEARNING_RATE}.pt')
        print(f"****** Model checkpoint saved at epochs {self.epochs[-1]} ******")

In [5]:
if __name__ == '__main__':
    trainer = FineTunner()
    trainer.process(load=True)

Parameter: 102595923
Classes: 1000
Parameter: 102595923
Classes: 1000
Epoch: 0
****** Reset epochs and losses ******


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 0, Batch   100] loss: 9.047
[Epoch 0, Batch   200] loss: 8.299
[Epoch 0, Batch   300] loss: 8.121
[Epoch 0, Batch   400] loss: 8.081
[Epoch 0, Batch   500] loss: 7.979
[Epoch 0, Batch   600] loss: 7.838
[Epoch 0, Batch   700] loss: 7.678
[Epoch 0, Batch   800] loss: 7.393
[Epoch 0, Batch   900] loss: 6.791
[Epoch 0, Batch  1000] loss: 5.880
****** Model checkpoint saved at epochs 1 ******
[Epoch 0, Batch  1100] loss: 4.741
[Epoch 0, Batch  1200] loss: 4.167
[Epoch 0, Batch  1300] loss: 3.868
[Epoch 0, Batch  1400] loss: 3.413
[Epoch 0, Batch  1500] loss: 3.297
[Epoch 0, Batch  1600] loss: 3.173
[Epoch 0, Batch  1700] loss: 2.992
[Epoch 0, Batch  1800] loss: 2.976
[Epoch 0, Batch  1900] loss: 3.007
[Epoch 0, Batch  2000] loss: 2.976
****** Model checkpoint saved at epochs 1 ******
[Epoch 0, Batch  2100] loss: 2.921
[Epoch 0, Batch  2200] loss: 2.906
[Epoch 0, Batch  2300] loss: 2.911
[Epoch 0, Batch  2400] loss: 2.767
[Epoch 0, Batch  2500] loss: 2.773
[Epoch 0, Batch  2600] loss

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 1, Batch   100] loss: 2.088
[Epoch 1, Batch   200] loss: 2.104
[Epoch 1, Batch   300] loss: 2.136
[Epoch 1, Batch   400] loss: 2.062
[Epoch 1, Batch   500] loss: 2.111
[Epoch 1, Batch   600] loss: 2.110
[Epoch 1, Batch   700] loss: 2.120
[Epoch 1, Batch   800] loss: 2.053
[Epoch 1, Batch   900] loss: 2.090
[Epoch 1, Batch  1000] loss: 2.060
****** Model checkpoint saved at epochs 2 ******
[Epoch 1, Batch  1100] loss: 2.095
[Epoch 1, Batch  1200] loss: 2.133
[Epoch 1, Batch  1300] loss: 2.106
[Epoch 1, Batch  1400] loss: 2.043
[Epoch 1, Batch  1500] loss: 2.080
[Epoch 1, Batch  1600] loss: 2.063
[Epoch 1, Batch  1700] loss: 2.088
[Epoch 1, Batch  1800] loss: 2.118
[Epoch 1, Batch  1900] loss: 2.074
[Epoch 1, Batch  2000] loss: 2.087
****** Model checkpoint saved at epochs 2 ******
[Epoch 1, Batch  2100] loss: 2.045
[Epoch 1, Batch  2200] loss: 2.052
[Epoch 1, Batch  2300] loss: 2.095
[Epoch 1, Batch  2400] loss: 2.066
[Epoch 1, Batch  2500] loss: 2.072
[Epoch 1, Batch  2600] loss

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 2, Batch   100] loss: 1.884
[Epoch 2, Batch   200] loss: 1.961
[Epoch 2, Batch   300] loss: 1.878
[Epoch 2, Batch   400] loss: 1.918
[Epoch 2, Batch   500] loss: 1.894
[Epoch 2, Batch   600] loss: 1.961
[Epoch 2, Batch   700] loss: 1.899
[Epoch 2, Batch   800] loss: 1.866
[Epoch 2, Batch   900] loss: 1.827
[Epoch 2, Batch  1000] loss: 1.879
****** Model checkpoint saved at epochs 3 ******
[Epoch 2, Batch  1100] loss: 1.902
[Epoch 2, Batch  1200] loss: 1.896
[Epoch 2, Batch  1300] loss: 1.839
[Epoch 2, Batch  1400] loss: 1.896
[Epoch 2, Batch  1500] loss: 1.941
[Epoch 2, Batch  1600] loss: 1.941
[Epoch 2, Batch  1700] loss: 1.911
[Epoch 2, Batch  1800] loss: 1.943
[Epoch 2, Batch  1900] loss: 1.847
[Epoch 2, Batch  2000] loss: 1.884
****** Model checkpoint saved at epochs 3 ******
[Epoch 2, Batch  2100] loss: 1.875
[Epoch 2, Batch  2200] loss: 1.843
[Epoch 2, Batch  2300] loss: 1.916
[Epoch 2, Batch  2400] loss: 1.999
[Epoch 2, Batch  2500] loss: 1.871
[Epoch 2, Batch  2600] loss

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 3, Batch   100] loss: 1.783
[Epoch 3, Batch   200] loss: 1.763
[Epoch 3, Batch   300] loss: 1.765
[Epoch 3, Batch   400] loss: 1.786
[Epoch 3, Batch   500] loss: 1.693
[Epoch 3, Batch   600] loss: 1.768
[Epoch 3, Batch   700] loss: 1.801
[Epoch 3, Batch   800] loss: 1.816
[Epoch 3, Batch   900] loss: 1.794
[Epoch 3, Batch  1000] loss: 1.764
****** Model checkpoint saved at epochs 4 ******
[Epoch 3, Batch  1100] loss: 1.700
[Epoch 3, Batch  1200] loss: 1.794
[Epoch 3, Batch  1300] loss: 1.710
[Epoch 3, Batch  1400] loss: 1.718
[Epoch 3, Batch  1500] loss: 1.773
[Epoch 3, Batch  1600] loss: 1.739
[Epoch 3, Batch  1700] loss: 1.756
[Epoch 3, Batch  1800] loss: 1.739
[Epoch 3, Batch  1900] loss: 1.803
[Epoch 3, Batch  2000] loss: 1.783
****** Model checkpoint saved at epochs 4 ******
[Epoch 3, Batch  2100] loss: 1.778
[Epoch 3, Batch  2200] loss: 1.772
[Epoch 3, Batch  2300] loss: 1.737
[Epoch 3, Batch  2400] loss: 1.756
[Epoch 3, Batch  2500] loss: 1.658
[Epoch 3, Batch  2600] loss

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 4, Batch   100] loss: 1.684
[Epoch 4, Batch   200] loss: 1.624
[Epoch 4, Batch   300] loss: 1.621
[Epoch 4, Batch   400] loss: 1.622
[Epoch 4, Batch   500] loss: 1.719
[Epoch 4, Batch   600] loss: 1.605
[Epoch 4, Batch   700] loss: 1.601
[Epoch 4, Batch   800] loss: 1.632
[Epoch 4, Batch   900] loss: 1.610
[Epoch 4, Batch  1000] loss: 1.679
****** Model checkpoint saved at epochs 5 ******
[Epoch 4, Batch  1100] loss: 1.579
[Epoch 4, Batch  1200] loss: 1.653
[Epoch 4, Batch  1300] loss: 1.637
[Epoch 4, Batch  1400] loss: 1.650
[Epoch 4, Batch  1500] loss: 1.593
[Epoch 4, Batch  1600] loss: 1.628
[Epoch 4, Batch  1700] loss: 1.640
[Epoch 4, Batch  1800] loss: 1.623
[Epoch 4, Batch  1900] loss: 1.649
[Epoch 4, Batch  2000] loss: 1.625
****** Model checkpoint saved at epochs 5 ******
[Epoch 4, Batch  2100] loss: 1.595
[Epoch 4, Batch  2200] loss: 1.640
[Epoch 4, Batch  2300] loss: 1.636
[Epoch 4, Batch  2400] loss: 1.619
[Epoch 4, Batch  2500] loss: 1.626
[Epoch 4, Batch  2600] loss

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 5, Batch   100] loss: 1.497
[Epoch 5, Batch   200] loss: 1.523
[Epoch 5, Batch   300] loss: 1.514
[Epoch 5, Batch   400] loss: 1.521
[Epoch 5, Batch   500] loss: 1.541
[Epoch 5, Batch   600] loss: 1.527
[Epoch 5, Batch   700] loss: 1.592
[Epoch 5, Batch   800] loss: 1.471
[Epoch 5, Batch   900] loss: 1.542
[Epoch 5, Batch  1000] loss: 1.507
****** Model checkpoint saved at epochs 6 ******
[Epoch 5, Batch  1100] loss: 1.576
[Epoch 5, Batch  1200] loss: 1.511
[Epoch 5, Batch  1300] loss: 1.536
[Epoch 5, Batch  1400] loss: 1.534
[Epoch 5, Batch  1500] loss: 1.497
[Epoch 5, Batch  1600] loss: 1.509
[Epoch 5, Batch  1700] loss: 1.491
[Epoch 5, Batch  1800] loss: 1.546
[Epoch 5, Batch  1900] loss: 1.545
[Epoch 5, Batch  2000] loss: 1.560
****** Model checkpoint saved at epochs 6 ******
[Epoch 5, Batch  2100] loss: 1.558
[Epoch 5, Batch  2200] loss: 1.528
[Epoch 5, Batch  2300] loss: 1.508
[Epoch 5, Batch  2400] loss: 1.548
[Epoch 5, Batch  2500] loss: 1.500
[Epoch 5, Batch  2600] loss

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 6, Batch   100] loss: 1.455
[Epoch 6, Batch   200] loss: 1.477
[Epoch 6, Batch   300] loss: 1.471
[Epoch 6, Batch   400] loss: 1.416
[Epoch 6, Batch   500] loss: 1.400
[Epoch 6, Batch   600] loss: 1.484
[Epoch 6, Batch   700] loss: 1.495
[Epoch 6, Batch   800] loss: 1.483
[Epoch 6, Batch   900] loss: 1.451
[Epoch 6, Batch  1000] loss: 1.495
****** Model checkpoint saved at epochs 7 ******
[Epoch 6, Batch  1100] loss: 1.431
[Epoch 6, Batch  1200] loss: 1.466
[Epoch 6, Batch  1300] loss: 1.399
[Epoch 6, Batch  1400] loss: 1.461
[Epoch 6, Batch  1500] loss: 1.452
[Epoch 6, Batch  1600] loss: 1.500
[Epoch 6, Batch  1700] loss: 1.414
[Epoch 6, Batch  1800] loss: 1.370
[Epoch 6, Batch  1900] loss: 1.484
[Epoch 6, Batch  2000] loss: 1.507
****** Model checkpoint saved at epochs 7 ******
[Epoch 6, Batch  2100] loss: 1.419
[Epoch 6, Batch  2200] loss: 1.410
[Epoch 6, Batch  2300] loss: 1.442
[Epoch 6, Batch  2400] loss: 1.422
[Epoch 6, Batch  2500] loss: 1.436
[Epoch 6, Batch  2600] loss

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 7, Batch   100] loss: 1.376
[Epoch 7, Batch   200] loss: 1.362
[Epoch 7, Batch   300] loss: 1.402
[Epoch 7, Batch   400] loss: 1.381
[Epoch 7, Batch   500] loss: 1.415
[Epoch 7, Batch   600] loss: 1.385
[Epoch 7, Batch   700] loss: 1.407
[Epoch 7, Batch   800] loss: 1.409
[Epoch 7, Batch   900] loss: 1.343
[Epoch 7, Batch  1000] loss: 1.390
****** Model checkpoint saved at epochs 8 ******
[Epoch 7, Batch  1100] loss: 1.374
[Epoch 7, Batch  1200] loss: 1.396
[Epoch 7, Batch  1300] loss: 1.419
[Epoch 7, Batch  1400] loss: 1.451
[Epoch 7, Batch  1500] loss: 1.432
[Epoch 7, Batch  1600] loss: 1.380
[Epoch 7, Batch  1700] loss: 1.395
[Epoch 7, Batch  1800] loss: 1.362
[Epoch 7, Batch  1900] loss: 1.363
[Epoch 7, Batch  2000] loss: 1.388
****** Model checkpoint saved at epochs 8 ******
[Epoch 7, Batch  2100] loss: 1.423
[Epoch 7, Batch  2200] loss: 1.412
[Epoch 7, Batch  2300] loss: 1.386
[Epoch 7, Batch  2400] loss: 1.402
[Epoch 7, Batch  2500] loss: 1.394
[Epoch 7, Batch  2600] loss

In [6]:
# 벨리데이션 10% 쪼개기 / 얼리스탑 구현
# positive도 더 길게 돌려보기.
# 다음으로 MAE 코드 보면 될듯.