## Library

In [1]:
import os

import timm
import torch
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.optim import AdamW, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm, tqdm_notebook
from tqdm.auto import tqdm as auto_tqdm

from vit_pooling import ViTPooling
from positional_enhancement import PositionalEnhanceViT


# Index of the GPU to train on when the host exposes several devices.
PREFERRED_GPU_INDEX = 2  # GPU Number

# Enumerate the visible CUDA devices (both lists stay empty on CPU-only hosts).
gpu_ids = list(range(torch.cuda.device_count())) if torch.cuda.is_available() else []
device_names = [torch.cuda.get_device_name(gpu_id) for gpu_id in gpu_ids]
print(gpu_ids)
print(device_names)

if len(gpu_ids) > 1 and PREFERRED_GPU_INDEX in gpu_ids:
    # Multi-GPU host that actually has the preferred device.
    gpu = 'cuda:' + str(PREFERRED_GPU_INDEX)
else:
    # Single GPU, CPU-only, or a multi-GPU host with fewer devices than the
    # preferred index (the previous hard-coded gpu_ids[2] raised IndexError
    # with exactly two GPUs): fall back to the default device.
    gpu = "cuda" if torch.cuda.is_available() else "cpu"

[0, 1, 2, 3]
['NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090']


## Hyperparameters

In [2]:
# Device chosen by the GPU-discovery cell.
device = gpu

# Data loading.
BATCH_SIZE = 64
NUM_WORKERS = 2

# Optimisation schedule.
NUM_EPOCHS = 12
LEARNING_RATE = 3e-04
WEIGHT_DECAY = 0.1

# Size of the classification head.
NUM_CLASSES = 1000

# Checkpoints written by the pre-training and fine-tuning runs.
pre_model_path = f'./save/ViT_timm_vit_base_patch16_224_in21k_positionBaselineEp{NUM_EPOCHS}.pt'
fine_model_path = f'./save/ViT_timm_vit_base_patch16_224_in21k_positionBaselineEp8_position_i2012_ep{NUM_EPOCHS}_lr{LEARNING_RATE}.pt'

# Checkpoints restored when a trainer is started with load=True.
pre_load_model_path = './save/ViT_timm_vit_base_patch16_224_in21k_positionBaselineEp8.pt'
fine_load_model_path = './save/ViT_timm_vit_base_patch16_224_in21k_positionBaselineEp8.pt'

## Dataset

In [3]:
# Training pipeline: random scale/aspect crop to 224x224, then rescale pixel
# values from [0, 1] to [-1, 1] via Normalize(mean=0.5, std=0.5).
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# Evaluation pipeline: deterministic resize + centre crop. The previous
# RandomResizedCrop made validation results vary from run to run.
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# pre_train_set = torchvision.datasets.ImageFolder('./data/ImageNet-21k', transform=transform_train)
# pre_train_loader = data.DataLoader(pre_train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
train_set = torchvision.datasets.ImageFolder('../../YJ/ILSVRC2012/train', transform=transform_train)
train_loader = data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_set = torchvision.datasets.ImageFolder('../../YJ/ILSVRC2012/val', transform=transform_test)
# Evaluation order does not need to be randomised; keep it stable.
test_loader = data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

## Pre-training Class

In [4]:
class PreTrainer(object):
    """Pre-trains a ``PositionalEnhanceViT`` on ``train_loader``.

    ``build_model`` creates the network (optionally restoring weights),
    ``pretrain_model`` runs the optimisation loop, and ``save_model`` writes a
    checkpoint to ``pre_model_path``.
    """

    def __init__(self):
        self.model = None      # network being trained; set by build_model
        self.optimizer = None  # optimizer whose state is stored in checkpoints
        self.epochs = []       # 1-based epoch number logged at each periodic checkpoint
        self.losses = []       # mean loss over the 1000 batches before each periodic checkpoint

    def process(self, load=False):
        """Build the model, pre-train it, and write a final checkpoint.

        Args:
            load: when True, initialise the weights from ``pre_load_model_path``.
        """
        self.build_model(load)
        self.pretrain_model()
        self.save_model()

    def build_model(self, load):
        """Create the model on ``device`` and optionally restore saved weights.

        Args:
            load: when True, load the model weights from
                ``pre_load_model_path``. The epoch/loss history stored in that
                checkpoint is reported and then discarded.
        """
        self.model = PositionalEnhanceViT(NUM_CLASSES).to(device)
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        print(f'Classes: {self.model.num_classes}')
        # Placeholder so save_model() is usable before training starts;
        # pretrain_model() replaces it with the optimizer that actually trains.
        self.optimizer = AdamW(self.model.parameters(), lr=0)
        if load:
            # map_location lets a checkpoint written on another GPU be
            # restored onto the currently selected device.
            checkpoint = torch.load(pre_load_model_path, map_location=device)
            self.epochs = checkpoint['epochs']
            self.model.load_state_dict(checkpoint['model'])
            self.losses = checkpoint['losses']
            print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
            print(f'Classes: {self.model.num_classes}')
            if self.epochs:  # a checkpoint may carry an empty history
                print(f'Epoch: {self.epochs[-1]}')
            print(f'****** Reset epochs and losses ******')
            self.epochs = []
            self.losses = []

    def pretrain_model(self):
        """Train ``self.model`` for ``NUM_EPOCHS`` epochs with AdamW and cosine LR decay.

        The mean loss is printed every 100 batches. Every 1000 batches the
        history is extended and a checkpoint is written via ``save_model``.
        """
        model = self.model
        criterion = nn.CrossEntropyLoss()
        optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
        # Publish the real optimizer immediately so every checkpoint, including
        # the final one on short runs, stores its state rather than the
        # lr=0 placeholder created in build_model().
        self.optimizer = optimizer

        for epoch in range(NUM_EPOCHS):
            running_loss = 0.0  # accumulated over the current 100-batch window
            saving_loss = 0.0   # accumulated over the current 1000-batch window
            progress = auto_tqdm(enumerate(train_loader, 0), total=len(train_loader))
            for i, (inputs, labels) in progress:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                running_loss += batch_loss
                saving_loss += batch_loss
                if i % 100 == 99:
                    print(f'[Epoch {epoch}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0
                if i % 1000 == 999:
                    self.epochs.append(epoch + 1)
                    self.losses.append(saving_loss / 1000)
                    self.save_model()
                    saving_loss = 0.0
            # The cosine schedule advances once per epoch.
            scheduler.step()
        print('****** Finished Pre-training ******')

    def save_model(self):
        """Write model, optimizer and loss history to ``pre_model_path``."""
        checkpoint = {
            'epochs': self.epochs,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'losses': self.losses,
        }
        # Create the target directory if needed so a long run cannot fail at save time.
        save_dir = os.path.dirname(pre_model_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(checkpoint, pre_model_path)
        if self.epochs:
            print(f"****** Model checkpoint saved at epochs {self.epochs[-1]} ******")
        else:
            # No periodic checkpoint was logged (fewer than 1000 batches seen);
            # indexing self.epochs[-1] here used to raise IndexError.
            print("****** Model checkpoint saved (no epochs recorded yet) ******")

## Fine-tuning Class

In [5]:
class FineTunner(object):
    """Fine-tunes a ``PositionalEnhanceViT`` on ``train_loader``.

    ``build_model`` creates the network (optionally restoring weights),
    ``finetune_model`` runs the optimisation loop, and ``save_model`` writes a
    checkpoint to ``fine_model_path``.
    """

    def __init__(self):
        self.model = None      # network being trained; set by build_model
        self.optimizer = None  # optimizer whose state is stored in checkpoints
        self.epochs = []       # 1-based epoch number logged at each periodic checkpoint
        self.losses = []       # mean loss over the 1000 batches before each periodic checkpoint

    def process(self, load=False):
        """Build the model, fine-tune it, and write a final checkpoint.

        Args:
            load: when True, initialise the weights from ``fine_load_model_path``.
        """
        self.build_model(load)
        self.finetune_model()
        self.save_model()

    def build_model(self, load):
        """Create the model on ``device`` and optionally restore saved weights.

        Args:
            load: when True, load the model weights from
                ``fine_load_model_path``. The epoch/loss history stored in that
                checkpoint is reported and then discarded.
        """
        self.model = PositionalEnhanceViT(NUM_CLASSES).to(device)
        # Make every parameter of the `forward_vit` sub-module trainable.
        for param in self.model.forward_vit.parameters():
            param.requires_grad = True
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        print(f'Classes: {self.model.num_classes}')
        # Placeholder so save_model() is usable before training starts;
        # finetune_model() replaces it with the optimizer that actually trains.
        self.optimizer = SGD(self.model.parameters(), lr=0)
        if load:
            # map_location lets a checkpoint written on another GPU be
            # restored onto the currently selected device.
            checkpoint = torch.load(fine_load_model_path, map_location=device)
            self.epochs = checkpoint['epochs']
            self.model.load_state_dict(checkpoint['model'])
            self.losses = checkpoint['losses']
            print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
            print(f'Classes: {self.model.num_classes}')
            if self.epochs:  # a checkpoint may carry an empty history
                print(f'Epoch: {self.epochs[-1]}')
            print(f'****** Reset epochs and losses ******')
            self.epochs = []
            self.losses = []

    def finetune_model(self):
        """Train ``self.model`` for ``NUM_EPOCHS`` epochs with SGD (momentum 0.9) and cosine LR decay.

        The mean loss is printed every 100 batches. Every 1000 batches the
        history is extended and a checkpoint is written via ``save_model``.
        """
        model = self.model
        criterion = nn.CrossEntropyLoss()
        optimizer = SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
        scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
        # Publish the real optimizer immediately so every checkpoint, including
        # the final one on short runs, stores its state rather than the
        # lr=0 placeholder created in build_model().
        self.optimizer = optimizer

        for epoch in range(NUM_EPOCHS):
            running_loss = 0.0  # accumulated over the current 100-batch window
            saving_loss = 0.0   # accumulated over the current 1000-batch window
            progress = auto_tqdm(enumerate(train_loader, 0), total=len(train_loader))
            for i, (inputs, labels) in progress:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                running_loss += batch_loss
                saving_loss += batch_loss
                if i % 100 == 99:
                    print(f'[Epoch {epoch}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0
                if i % 1000 == 999:
                    self.epochs.append(epoch + 1)
                    self.losses.append(saving_loss / 1000)
                    self.save_model()
                    saving_loss = 0.0
            # The cosine schedule advances once per epoch.
            scheduler.step()
        print('****** Finished Fine-tuning ******')

    def save_model(self):
        """Write model, optimizer and loss history to ``fine_model_path``."""
        checkpoint = {
            'epochs': self.epochs,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'losses': self.losses,
        }
        # Create the target directory if needed so a long run cannot fail at save time.
        save_dir = os.path.dirname(fine_model_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(checkpoint, fine_model_path)
        if self.epochs:
            print(f"****** Model checkpoint saved at epochs {self.epochs[-1]} ******")
        else:
            # No periodic checkpoint was logged (fewer than 1000 batches seen);
            # indexing self.epochs[-1] here used to raise IndexError.
            print("****** Model checkpoint saved (no epochs recorded yet) ******")

In [None]:
if __name__ == '__main__':
    # Fine-tuning entry point: load=True restores the weights stored at
    # `fine_load_model_path` before training; checkpoints are written to
    # `fine_model_path`.
    trainer = FineTunner()
    trainer.process(load=True)

Parameter: 107883659
Classes: 1000
Parameter: 107883659
Classes: 1000
Epoch: 8
****** Reset epochs and losses ******


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 0, Batch   100] loss: 1.033
[Epoch 0, Batch   200] loss: 1.022
[Epoch 0, Batch   300] loss: 1.048
[Epoch 0, Batch   400] loss: 1.048
[Epoch 0, Batch   500] loss: 1.059
[Epoch 0, Batch   600] loss: 1.027
[Epoch 0, Batch   700] loss: 1.053
[Epoch 0, Batch   800] loss: 1.047
[Epoch 0, Batch   900] loss: 1.069
[Epoch 0, Batch  1000] loss: 1.036
****** Model checkpoint saved at epochs 1 ******
[Epoch 0, Batch  1100] loss: 1.040
[Epoch 0, Batch  1200] loss: 1.008
[Epoch 0, Batch  1300] loss: 1.012
[Epoch 0, Batch  1400] loss: 1.014
[Epoch 0, Batch  1500] loss: 1.023
[Epoch 0, Batch  1600] loss: 1.024
[Epoch 0, Batch  1700] loss: 1.027
[Epoch 0, Batch  1800] loss: 0.995
[Epoch 0, Batch  1900] loss: 1.019
[Epoch 0, Batch  2000] loss: 1.075
****** Model checkpoint saved at epochs 1 ******
[Epoch 0, Batch  2100] loss: 0.995
[Epoch 0, Batch  2200] loss: 0.984
[Epoch 0, Batch  2300] loss: 1.010
[Epoch 0, Batch  2400] loss: 1.034
[Epoch 0, Batch  2500] loss: 0.971
[Epoch 0, Batch  2600] loss

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 1, Batch   100] loss: 0.896
[Epoch 1, Batch   200] loss: 0.902
[Epoch 1, Batch   300] loss: 0.895
[Epoch 1, Batch   400] loss: 0.888
[Epoch 1, Batch   500] loss: 0.857
[Epoch 1, Batch   600] loss: 0.861
[Epoch 1, Batch   700] loss: 0.910
[Epoch 1, Batch   800] loss: 0.897
[Epoch 1, Batch   900] loss: 0.863
[Epoch 1, Batch  1000] loss: 0.912
****** Model checkpoint saved at epochs 2 ******
[Epoch 1, Batch  1100] loss: 0.895
[Epoch 1, Batch  1200] loss: 0.898
[Epoch 1, Batch  1300] loss: 0.879
[Epoch 1, Batch  1400] loss: 0.881
[Epoch 1, Batch  1500] loss: 0.875
[Epoch 1, Batch  1600] loss: 0.874
[Epoch 1, Batch  1700] loss: 0.887
[Epoch 1, Batch  1800] loss: 0.912
[Epoch 1, Batch  1900] loss: 0.844
[Epoch 1, Batch  2000] loss: 0.850
****** Model checkpoint saved at epochs 2 ******
[Epoch 1, Batch  2100] loss: 0.843
[Epoch 1, Batch  2200] loss: 0.859
[Epoch 1, Batch  2300] loss: 0.866
[Epoch 1, Batch  2400] loss: 0.856
[Epoch 1, Batch  2500] loss: 0.855
[Epoch 1, Batch  2600] loss

  0%|          | 0/20019 [00:00<?, ?it/s]

[Epoch 2, Batch   100] loss: 0.833
[Epoch 2, Batch   200] loss: 0.807
[Epoch 2, Batch   300] loss: 0.800
[Epoch 2, Batch   400] loss: 0.795
[Epoch 2, Batch   500] loss: 0.826
[Epoch 2, Batch   600] loss: 0.818
[Epoch 2, Batch   700] loss: 0.794
[Epoch 2, Batch   800] loss: 0.840
[Epoch 2, Batch   900] loss: 0.827
[Epoch 2, Batch  1000] loss: 0.810
****** Model checkpoint saved at epochs 3 ******
[Epoch 2, Batch  1100] loss: 0.814
[Epoch 2, Batch  1200] loss: 0.799
[Epoch 2, Batch  1300] loss: 0.799
[Epoch 2, Batch  1400] loss: 0.820
[Epoch 2, Batch  1500] loss: 0.770
[Epoch 2, Batch  1600] loss: 0.786
[Epoch 2, Batch  1700] loss: 0.829
[Epoch 2, Batch  1800] loss: 0.845
[Epoch 2, Batch  1900] loss: 0.806
[Epoch 2, Batch  2000] loss: 0.823
****** Model checkpoint saved at epochs 3 ******
[Epoch 2, Batch  2100] loss: 0.795
[Epoch 2, Batch  2200] loss: 0.834
[Epoch 2, Batch  2300] loss: 0.833
[Epoch 2, Batch  2400] loss: 0.859
[Epoch 2, Batch  2500] loss: 0.829
[Epoch 2, Batch  2600] loss

In [None]:
# TODO: position has only been trained up to 8 so far; try training up to 12 as well.
# TODO: if that does not work, also set the head-side parameters to 0.